dan92 commited on
Commit
029dbc0
·
verified ·
1 Parent(s): a7d2f0d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +630 -762
main.py CHANGED
@@ -1,762 +1,630 @@
1
- import asyncio
2
- import json
3
- import sys
4
- import uuid
5
- import base64
6
- import re
7
- import os
8
- import argparse
9
- from datetime import datetime, timezone, timedelta
10
- from typing import List, Optional
11
-
12
- import httpx
13
- import uvicorn
14
- from fastapi import (
15
- BackgroundTasks,
16
- FastAPI,
17
- HTTPException,
18
- Request,
19
- Response,
20
- status,
21
- )
22
- from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
23
- from fastapi.middleware.cors import CORSMiddleware
24
- from fastapi.staticfiles import StaticFiles
25
-
26
- from bearer_token import BearerTokenGenerator
27
-
28
- # 模型列表(根据需求,可自行调整)
29
- MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
30
-
31
- # 默认端口
32
- INITIAL_PORT = 3000
33
-
34
- # 外部API的URL
35
- EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
36
-
37
- # 定义 images 目录
38
- IMAGES_DIR = "images"
39
-
40
- # 确保 images 目录存在
41
- def ensure_images_dir_exists(directory: str = IMAGES_DIR):
42
- try:
43
- os.makedirs(directory, exist_ok=True)
44
- print(f"Directory '{directory}' is ready.")
45
- except Exception as e:
46
- print(f"Failed to create directory '{directory}': {e}")
47
- sys.exit(1)
48
-
49
- # 在挂载静态文件之前确保 images 目录存在
50
- ensure_images_dir_exists()
51
-
52
- # 初始化FastAPI应用
53
- app = FastAPI()
54
-
55
- # 挂载静态文件路由以提供 images 目录的内容
56
- app.mount("/images", StaticFiles(directory=IMAGES_DIR), name="images")
57
-
58
- # 添加CORS中间件
59
- app.add_middleware(
60
- CORSMiddleware,
61
- allow_origins=["*"], # 允许所有来源
62
- allow_credentials=True,
63
- allow_methods=["GET", "POST", "OPTIONS"], # 允许GET, POST, OPTIONS方法
64
- allow_headers=["Content-Type", "Authorization"], # 允许的头部
65
- )
66
-
67
- # 辅助函数
68
- def send_error_response(message: str, status_code: int = 400):
69
- """构建错误响应,并确保包含CORS头"""
70
- error_json = {"error": message}
71
- headers = {
72
- "Access-Control-Allow-Origin": "*",
73
- "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
74
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
75
- }
76
- return JSONResponse(status_code=status_code, content=error_json, headers=headers)
77
-
78
- def extract_path_from_markdown(markdown: str) -> Optional[str]:
79
- """
80
- 提取 Markdown 图片链接中的路径,匹配以 https://spc.unk/ 开头的 URL
81
- """
82
- pattern = re.compile(r'!\[.*?\]\(https://spc\.unk/(.*?)\)')
83
- match = pattern.search(markdown)
84
- if match:
85
- return match.group(1)
86
- return None
87
-
88
- async def fetch_get_url_from_storage(storage_url: str) -> Optional[str]:
89
- """
90
- storage URL 获取 JSON 并提取 getUrl
91
- """
92
- async with httpx.AsyncClient() as client:
93
- try:
94
- response = await client.get(storage_url)
95
- if response.status_code != 200:
96
- print(f"获取 storage URL 失败,状态码: {response.status_code}")
97
- return None
98
- json_response = response.json()
99
- return json_response.get("getUrl")
100
- except Exception as e:
101
- print(f"Error fetching getUrl from storage: {e}")
102
- return None
103
-
104
- async def download_image(image_url: str) -> Optional[bytes]:
105
- """
106
- 下载图像
107
- """
108
- async with httpx.AsyncClient() as client:
109
- try:
110
- response = await client.get(image_url)
111
- if response.status_code == 200:
112
- return response.content
113
- else:
114
- print(f"下载图像失败,状态码: {response.status_code}")
115
- return None
116
- except Exception as e:
117
- print(f"Error downloading image: {e}")
118
- return None
119
-
120
- def cleanup_images(images_dir: str = IMAGES_DIR, age_seconds: int = 60):
121
- """
122
- 清理 images 目录中创建时间超过指定秒数的图片
123
- """
124
- now = datetime.now(timezone.utc)
125
- cutoff_time = now - timedelta(seconds=age_seconds)
126
-
127
- if not os.path.exists(images_dir):
128
- return
129
-
130
- for filename in os.listdir(images_dir):
131
- file_path = os.path.join(images_dir, filename)
132
- if os.path.isfile(file_path):
133
- try:
134
- file_creation_time = datetime.fromtimestamp(os.path.getctime(file_path), timezone.utc)
135
- if file_creation_time < cutoff_time:
136
- os.remove(file_path)
137
- print(f"已删除旧图片: {filename}")
138
- except Exception as e:
139
- print(f"无法删除文件 {filename}: {e}")
140
-
141
- def save_base64_image(base64_str: str, images_dir: str = IMAGES_DIR) -> str:
142
- """
143
- 将Base64编码的图片保存到images目录,返回文件名
144
- """
145
- # 先清理1分钟前的所有图片
146
- cleanup_images(images_dir, age_seconds=60)
147
-
148
- try:
149
- image_data = base64.b64decode(base64_str)
150
- except base64.binascii.Error as e:
151
- print(f"Base64解码失��: {e}")
152
- raise ValueError("Invalid base64 image data")
153
-
154
- filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
155
- file_path = os.path.join(images_dir, filename)
156
- try:
157
- with open(file_path, "wb") as f:
158
- f.write(image_data)
159
- print(f"保存图片: {filename}")
160
- except Exception as e:
161
- print(f"保存图片失败: {e}")
162
- raise
163
-
164
- return filename
165
-
166
- def is_base64_image(url: str) -> bool:
167
- """
168
- 判断URL是否为Base64编码的图片
169
- """
170
- return url.startswith("data:image/")
171
-
172
- # 根路径GET请求处理
173
- @app.get("/", response_class=HTMLResponse)
174
- async def read_root():
175
- """返回欢迎页面"""
176
- html_content = """
177
- <html>
178
- <head>
179
- <title>Welcome to API</title>
180
- </head>
181
- <body>
182
- <h1>Welcome to API</h1>
183
- <p>You can send messages to the model and receive responses.</p>
184
- </body>
185
- </html>
186
- """
187
- return HTMLResponse(content=html_content, status_code=200)
188
-
189
- # 聊天完成处理(保留原有逻辑,未修改)
190
- @app.post("/ai/v1/chat/completions")
191
- @app.post("/v1/chat/completions")
192
- async def chat_completions(request: Request, background_tasks: BackgroundTasks):
193
- """
194
- 处理聊天完成请求
195
- """
196
- try:
197
- request_body = await request.json()
198
- except json.JSONDecodeError:
199
- raise HTTPException(status_code=400, detail="Invalid JSON")
200
-
201
- # 打印接收到的请求
202
- print("Received Completion JSON:", json.dumps(request_body, ensure_ascii=False))
203
-
204
- # 处理消息内容
205
- messages = request_body.get("messages", [])
206
- temperature = request_body.get("temperature", 1.0)
207
- #top_p = request_body.get("top_p", 1.0)
208
- max_tokens = request_body.get("max_tokens", 8000)
209
- model = request_body.get("model", "gpt-4o")
210
- is_stream = request_body.get("stream", False) # 获取 stream 字段
211
-
212
- # 验证模型
213
- if model not in MODELS:
214
- raise HTTPException(status_code=400, detail=f"无效的 model: {model}. 可用的模型有: {', '.join(MODELS)}")
215
-
216
- has_image = False
217
- has_text = False
218
-
219
- # 清理和提取消息内容
220
- cleaned_messages = []
221
- for message in messages:
222
- content = message.get("content", "")
223
- if isinstance(content, list):
224
- text_parts = []
225
- images = []
226
- for item in content:
227
- if "text" in item:
228
- text_parts.append(item.get("text", ""))
229
- elif "image_url" in item:
230
- has_image = True
231
- image_info = item.get("image_url", {})
232
- url = image_info.get("url", "")
233
- if is_base64_image(url):
234
- # 解码并保存图片
235
- try:
236
- base64_str = url.split(",")[1]
237
- filename = save_base64_image(base64_str)
238
- base_url = app.state.base_url
239
- image_url = f"{base_url}/images/{filename}"
240
- images.append({"data": image_url})
241
- except (IndexError, ValueError) as e:
242
- print(f"处理Base64图片失败: {e}")
243
- continue
244
- else:
245
- images.append({"data": url})
246
- extracted_content = " ".join(text_parts).strip()
247
- if extracted_content:
248
- has_text = True
249
- message["content"] = extracted_content
250
- if images:
251
- message["images"] = images
252
- cleaned_messages.append(message)
253
- print("Extracted:", extracted_content)
254
- else:
255
- if images:
256
- has_image = True
257
- message["content"] = ""
258
- message["images"] = images
259
- cleaned_messages.append(message)
260
- print("Extracted image only.")
261
- else:
262
- print("Deleted message with empty content.")
263
- elif isinstance(content, str):
264
- content_str = content.strip()
265
- if content_str:
266
- has_text = True
267
- new_message = {"role": message.get("role"), "content": content_str}
268
- cleaned_messages.append(new_message)
269
- print("Retained content:", content_str)
270
- else:
271
- print("Deleted message with empty content.")
272
- else:
273
- print("Deleted non-expected type of content message.")
274
-
275
- if not cleaned_messages:
276
- raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
277
-
278
- # 构建新的请求JSON
279
- new_request_json = {
280
- "function_image_gen": False,
281
- "function_web_search": True,
282
- "max_tokens": max_tokens,
283
- "model": model,
284
- "source": "chat/pro",
285
- "temperature": temperature,
286
- #"top_p": top_p,
287
- "messages": cleaned_messages,
288
- }
289
-
290
- modified_request_body = json.dumps(new_request_json, ensure_ascii=False)
291
- print("Modified Request JSON:", modified_request_body)
292
-
293
- # 获取Bearer Token
294
- tmp_token = BearerTokenGenerator.get_bearer(modified_request_body)
295
- if not tmp_token:
296
- raise HTTPException(status_code=500, detail="无法生成 Bearer Token")
297
-
298
- bearer_token, formatted_date = tmp_token
299
-
300
- headers = {
301
- "Date": formatted_date,
302
- "Client-time-zone": "-05:00",
303
- "Authorization": bearer_token,
304
- "User-Agent": "ChatOn_Android/1.55.488",
305
- "Accept-Language": "en-US",
306
- "X-Cl-Options": "hb",
307
- "Content-Type": "application/json; charset=UTF-8",
308
- }
309
-
310
- if is_stream:
311
- import uuid
312
- from datetime import datetime, timezone
313
-
314
- # 定义 should_filter_out 函数
315
- def should_filter_out(json_data):
316
- if 'ping' in json_data:
317
- return True
318
- if 'data' in json_data:
319
- data = json_data['data']
320
- if 'analytics' in data:
321
- return True
322
- if 'operation' in data and 'message' in data:
323
- return True
324
- return False
325
-
326
- # 定义 generate_id 函数
327
- def generate_id():
328
- return uuid.uuid4().hex[:24]
329
-
330
- # 流式响应处理
331
- async def event_generator():
332
- async with httpx.AsyncClient(timeout=None) as client_stream:
333
- try:
334
- async with client_stream.stream("POST", EXTERNAL_API_URL, headers=headers, content=modified_request_body) as streamed_response:
335
- async for line in streamed_response.aiter_lines():
336
- if line.startswith("data: "):
337
- data = line[6:].strip()
338
- if data == "[DONE]":
339
- # 通知客户端流结束
340
- yield "data: [DONE]\n\n"
341
- break
342
- try:
343
- sse_json = json.loads(data)
344
-
345
- # 判断是否需要过滤
346
- if should_filter_out(sse_json):
347
- continue
348
-
349
- # 处理包含 web sources 的消息
350
- if 'data' in sse_json and 'web' in sse_json['data']:
351
- web_data = sse_json['data']['web']
352
- if 'sources' in web_data:
353
- sources = web_data['sources']
354
- urls_list = []
355
- for source in sources:
356
- if 'url' in source:
357
- urls_list.append(source['url'])
358
- urls_content = '\n\n'.join(urls_list)
359
- print(f"从 API 接收到的内容: {urls_content}")
360
- # 构造新的 SSE 消息,填入 content 字段
361
- new_sse_json = {
362
- "id": generate_id(),
363
- "object": "chat.completion.chunk",
364
- "created": int(datetime.now(timezone.utc).timestamp()),
365
- "model": sse_json.get("model", "gpt-4o"),
366
- "choices": [
367
- {
368
- "delta": {"content": "\n" + urls_content + "\n"},
369
- "index": 0,
370
- "finish_reason": None
371
- }
372
- ]
373
- }
374
- new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
375
- yield new_sse_line
376
- else:
377
- # 尝试打印内容
378
- if 'choices' in sse_json:
379
- for choice in sse_json['choices']:
380
- delta = choice.get('delta', {})
381
- content = delta.get('content')
382
- if content:
383
- print(content, end='')
384
- # 直接转发其他消息
385
- yield f"data: {data}\n\n"
386
- except json.JSONDecodeError as e:
387
- print(f"JSON解析错误: {e}")
388
- continue
389
- else:
390
- # 忽略不以 "data: " 开头的行
391
- continue
392
- except httpx.RequestError as exc:
393
- print(f"外部API请求失败: {exc}")
394
- yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
395
-
396
- return StreamingResponse(
397
- event_generator(),
398
- media_type="text/event-stream",
399
- headers={
400
- "Cache-Control": "no-cache",
401
- "Connection": "keep-alive",
402
- # CORS头已通过中间件处理,无需在这里重复添加
403
- },
404
- )
405
- else:
406
- # 非流式响应处理
407
- async with httpx.AsyncClient(timeout=None) as client:
408
- try:
409
- response = await client.post(
410
- EXTERNAL_API_URL,
411
- headers=headers,
412
- content=modified_request_body,
413
- timeout=None
414
- )
415
-
416
- if response.status_code != 200:
417
- raise HTTPException(
418
- status_code=response.status_code,
419
- detail=f"API 错误: {response.status_code}",
420
- )
421
-
422
- sse_lines = response.text.splitlines()
423
- content_builder = ""
424
- images_urls = []
425
-
426
- for line in sse_lines:
427
- if line.startswith("data: "):
428
- data = line[6:].strip()
429
- if data == "[DONE]":
430
- break
431
- try:
432
- sse_json = json.loads(data)
433
- if "choices" in sse_json:
434
- for choice in sse_json["choices"]:
435
- if "delta" in choice:
436
- delta = choice["delta"]
437
- if "content" in delta:
438
- content_builder += delta["content"]
439
- except json.JSONDecodeError:
440
- print("JSON解析错误")
441
- continue
442
-
443
- openai_response = {
444
- "id": f"chatcmpl-{uuid.uuid4()}",
445
- "object": "chat.completion",
446
- "created": int(datetime.now(timezone.utc).timestamp()),
447
- "model": model, # 使用用户传入的model
448
- "choices": [
449
- {
450
- "index": 0,
451
- "message": {
452
- "role": "assistant",
453
- "content": content_builder,
454
- },
455
- "finish_reason": "stop",
456
- }
457
- ],
458
- }
459
-
460
- # 处理图片(如果有)
461
- if has_image:
462
- images = []
463
- for message in cleaned_messages:
464
- if "images" in message:
465
- for img in message["images"]:
466
- images.append({"data": img["data"]})
467
- openai_response["choices"][0]["message"]["images"] = images
468
-
469
- return JSONResponse(content=openai_response, status_code=200)
470
- except httpx.RequestError as exc:
471
- raise HTTPException(status_code=500, detail=f"请求失败: {str(exc)}")
472
- except Exception as exc:
473
- raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
474
-
475
- # 图像生成处理
476
- @app.post("/ai/v1/images/generations")
477
- @app.post("/v1/images/generations")
478
- async def images_generations(request: Request):
479
- """
480
- 处理图像生成请求
481
- """
482
- try:
483
- request_body = await request.json()
484
- except json.JSONDecodeError:
485
- return send_error_response("Invalid JSON", status_code=400)
486
-
487
- print("Received Image Generations JSON:", json.dumps(request_body, ensure_ascii=False))
488
-
489
- # 验证必需的字段
490
- if "prompt" not in request_body:
491
- return send_error_response("缺少必需的字段: prompt", status_code=400)
492
-
493
- user_prompt = request_body.get("prompt", "").strip()
494
- response_format = request_body.get("response_format", "").strip().lower()
495
- model = request_body.get("model", "gpt-4o")
496
- n = request_body.get("n", 1) # 生成图像数量,默认1
497
- size = request_body.get("size", "1024x1024") # 图片尺寸,默认1024x1024
498
-
499
- is_base64_response = response_format == "b64_json"
500
-
501
- if not user_prompt:
502
- return send_error_response("Prompt 不能为空。", status_code=400)
503
-
504
- print(f"Prompt: {user_prompt}")
505
- print(f"Response Format: {response_format}")
506
- print(f"Number of images to generate (n): {n}")
507
- print(f"Size: {size}")
508
-
509
- # 设置最大尝试次数为 2 * n
510
- max_attempts = 2 * n
511
- print(f"Max Attempts: {max_attempts}")
512
-
513
- # 初始化用于存储多个 URL 的线程安全列表
514
- final_download_urls: List[str] = []
515
-
516
- async def attempt_generate_image(attempt: int) -> Optional[str]:
517
- """
518
- 尝试生成单张图像,带有重试机制。
519
- """
520
- try:
521
- # 构建新的 TextToImage JSON 请求体
522
- text_to_image_json = {
523
- "function_image_gen": True,
524
- "function_web_search": True,
525
- "image_aspect_ratio": "1:1", # 图片比例可选:1:1/9:19/16:9/4:3
526
- "image_style": "photographic", # 固定 image_style,可根据需要调整
527
- "max_tokens": 8000,
528
- "messages": [
529
- {
530
- "content": "You are a helpful artist, please draw a picture. Based on imagination, draw a picture with user message.",
531
- "role": "system"
532
- },
533
- {
534
- "content": "Draw: " + user_prompt,
535
- "role": "user"
536
- }
537
- ],
538
- "model": "gpt-4o",
539
- "source": "chat/free" # 固定 source
540
- }
541
-
542
- modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
543
- print(f"Attempt {attempt} - Modified Request JSON: {modified_request_body}")
544
-
545
- # 获取Bearer Token
546
- tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
547
- if not tmp_token:
548
- print(f"Attempt {attempt} - 无法生成 Bearer Token")
549
- return None
550
-
551
- bearer_token, formatted_date = tmp_token
552
-
553
- headers = {
554
- "Date": formatted_date,
555
- "Client-time-zone": "-05:00",
556
- "Authorization": bearer_token,
557
- "User-Agent": "ChatOn_Android/1.53.502",
558
- "Accept-Language": "en-US",
559
- "X-Cl-Options": "hb",
560
- "Content-Type": "application/json; charset=UTF-8",
561
- }
562
-
563
- async with httpx.AsyncClient(timeout=None) as client:
564
- response = await client.post(
565
- EXTERNAL_API_URL,
566
- headers=headers,
567
- content=modified_request_body,
568
- timeout=None
569
- )
570
-
571
- if response.status_code != 200:
572
- print(f"Attempt {attempt} - API 错误: {response.status_code}")
573
- return None
574
-
575
- # 读取 SSE 流并提取图像URL
576
- sse_lines = response.text.splitlines()
577
- image_markdown = ""
578
-
579
- for line in sse_lines:
580
- if line.startswith("data: "):
581
- data = line[6:].strip()
582
- if data == "[DONE]":
583
- break
584
- try:
585
- sse_json = json.loads(data)
586
- if "choices" in sse_json:
587
- for choice in sse_json["choices"]:
588
- delta = choice.get("delta", {})
589
- content = delta.get("content")
590
- if content:
591
- image_markdown += content
592
- except json.JSONDecodeError:
593
- print(f"Attempt {attempt} - JSON解析错误")
594
- continue
595
-
596
- # 检查Markdown文本是否为空
597
- if not image_markdown:
598
- print(f"Attempt {attempt} - 无法从 SSE 流中构建图像 Markdown。")
599
- return None
600
-
601
- # 从Markdown中提取图像路径
602
- extracted_path = extract_path_from_markdown(image_markdown)
603
- if not extracted_path:
604
- print(f"Attempt {attempt} - 无法从 Markdown 中提取路径。")
605
- return None
606
-
607
- print(f"Attempt {attempt} - 提取的路径: {extracted_path}")
608
-
609
- # 拼接最终的存储URL
610
- storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
611
- print(f"Attempt {attempt} - 存储URL: {storage_url}")
612
-
613
- # 获取最终下载URL
614
- final_download_url = await fetch_get_url_from_storage(storage_url)
615
- if not final_download_url:
616
- print(f"Attempt {attempt} - 无法从 storage URL 获取最终下载链接。")
617
- return None
618
-
619
- print(f"Attempt {attempt} - Final Download URL: {final_download_url}")
620
-
621
- return final_download_url
622
- except Exception as e:
623
- print(f"Attempt {attempt} - 处理响应时发生错误: {e}")
624
- return None
625
-
626
- # 定义一个异步任务池,限制并发数量
627
- semaphore = asyncio.Semaphore(10) # 限制同时进行的任务数为10
628
-
629
- async def generate_with_retries(attempt: int) -> Optional[str]:
630
- async with semaphore:
631
- return await attempt_generate_image(attempt)
632
-
633
- # 开始尝试生成图像
634
- for attempt in range(1, max_attempts + 1):
635
- needed = n - len(final_download_urls)
636
- if needed <= 0:
637
- break
638
-
639
- print(f"Attempt {attempt} - 需要生成的图像数量: {needed}")
640
-
641
- # 创建多个任务同时生成所需数量的图像
642
- tasks = [asyncio.create_task(generate_with_retries(attempt)) for _ in range(needed)]
643
-
644
- # 等待所有任务完成
645
- results = await asyncio.gather(*tasks, return_exceptions=True)
646
-
647
- for result in results:
648
- if isinstance(result, Exception):
649
- print(f"Attempt {attempt} - 任务发生异常: {result}")
650
- continue
651
- if result:
652
- final_download_urls.append(result)
653
- print(f"Attempt {attempt} - 收集到下载链接: {result}")
654
-
655
- # 检查是否已经收集到足够的下载链接
656
- if len(final_download_urls) >= n:
657
- break
658
-
659
- # 检查是否收集到足够的链接
660
- if len(final_download_urls) < n:
661
- print("已达到最大尝试次数,仍未收集到足够数量的下载链接。")
662
- return send_error_response("无法生成足够数量的图像。", status_code=500)
663
-
664
- # 根据 response_format 返回相应的响应
665
- data_array = []
666
-
667
- if is_base64_response:
668
- for download_url in final_download_urls[:n]:
669
- try:
670
- image_bytes = await download_image(download_url)
671
- if not image_bytes:
672
- print(f"无法从 URL 下载图像: {download_url}")
673
- continue
674
-
675
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
676
- data_array.append({
677
- "b64_json": image_base64
678
- })
679
- except Exception as e:
680
- print(f"处理图像时发生错误: {e}")
681
- continue
682
- else:
683
- for download_url in final_download_urls[:n]:
684
- data_array.append({
685
- "url": download_url
686
- })
687
-
688
- # 如果收集的 URL 数量不足 n,则通过复制现有的 URL 来填充
689
- while len(data_array) < n and len(data_array) > 0:
690
- for item in data_array.copy():
691
- if len(data_array) >= n:
692
- break
693
- data_array.append(item)
694
-
695
- # 构建最终响应
696
- response_json = {
697
- "created": int(datetime.now(timezone.utc).timestamp()),
698
- "data": data_array
699
- }
700
-
701
- # 如果data_array为空,返回错误
702
- if not data_array:
703
- return send_error_response("无法生成图像。", status_code=500)
704
-
705
- # 返回响应
706
- return JSONResponse(content=response_json, status_code=200)
707
-
708
- @app.get("/ai/v1/models", response_class=JSONResponse)
709
- @app.get("/v1/models", response_class=JSONResponse)
710
- async def get_models():
711
- models_data = {
712
- "object": "list",
713
- "data": [
714
- {"id": "gpt-4o", "object": "model"},
715
- {"id": "gpt-4o-mini", "object": "model"},
716
- {"id": "claude-3-5-sonnet", "object": "model"},
717
- {"id": "claude", "object": "model"}
718
- ]
719
- }
720
- return models_data
721
-
722
- async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
723
- """查找可用的端口号"""
724
- for port in range(start_port, end_port + 1):
725
- try:
726
- # 尝试绑定端口
727
- server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
728
- server.close()
729
- await server.wait_closed()
730
- return port
731
- except OSError:
732
- continue
733
- raise RuntimeError(f"No available ports between {start_port} and {end_port}")
734
-
735
- def main():
736
- parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
737
- parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
738
- parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
739
- args = parser.parse_args()
740
-
741
- base_url = args.base_url
742
- port = args.port
743
-
744
- # ���置 FastAPI 应用的 state
745
- app.state.base_url = base_url
746
-
747
- print(f"Starting server on port {port} with base_url: {base_url}")
748
-
749
- # 检查端口可用性
750
- try:
751
- port = asyncio.run(get_available_port(start_port=port))
752
- except RuntimeError as e:
753
- print(e)
754
- return
755
-
756
- print(f"Server running on available port: {port}")
757
-
758
- # 运行FastAPI应用
759
- uvicorn.run(app, host="0.0.0.0", port=port)
760
-
761
- if __name__ == "__main__":
762
- main()
 
1
+ import asyncio
2
+ import json
3
+ import sys
4
+ import uuid
5
+ import base64
6
+ import re
7
+ import os
8
+ import argparse
9
+ import time
10
+ from datetime import datetime, timezone
11
+ from typing import List, Optional
12
+
13
+ import httpx
14
+ import uvicorn
15
+ from fastapi import (
16
+ BackgroundTasks,
17
+ FastAPI,
18
+ HTTPException,
19
+ Request,
20
+ Response,
21
+ status,
22
+ )
23
+ from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+ from fastapi.staticfiles import StaticFiles
26
+
27
+ from bearer_token import BearerTokenGenerator
28
+
29
+ from fastapi import Depends, HTTPException, Security
30
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
31
+
32
+ # 模型列表
33
+ MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
34
+
35
+ # 默认端口
36
+ INITIAL_PORT = 3000
37
+
38
+ # 外部API的URL
39
+ EXTERNAL_API_URL = "https://api.chaton.ai/chats/stream"
40
+
41
+ # 初始化FastAPI应用
42
+ app = FastAPI()
43
+
44
+ # 添加CORS中间件
45
+ app.add_middleware(
46
+ CORSMiddleware,
47
+ allow_origins=["*"], # 允许所有来源
48
+ allow_credentials=True,
49
+ allow_methods=["GET", "POST", "OPTIONS"], # 允许GET, POST, OPTIONS方法
50
+ allow_headers=["Content-Type", "Authorization"], # 允许的头部
51
+ )
52
+
53
+ # 挂载静态文件路由以提供 images 目录的内容
54
+ app.mount("/images", StaticFiles(directory="images"), name="images")
55
+
56
+ # 辅助函数
57
+ def send_error_response(message: str, status_code: int = 400):
58
+ """构建错误响应,并确保包含CORS头"""
59
+ error_json = {"error": message}
60
+ headers = {
61
+ "Access-Control-Allow-Origin": "*",
62
+ "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
63
+ "Access-Control-Allow-Headers": "Content-Type, Authorization",
64
+ }
65
+ return JSONResponse(status_code=status_code, content=error_json, headers=headers)
66
+
67
+ def extract_path_from_markdown(markdown: str) -> Optional[str]:
68
+ """
69
+ 提取 Markdown 图片链接中的路径,匹配以 https://spc.unk/ 开头的 URL
70
+ """
71
+ pattern = re.compile(r'!\[.*?\]\(https://spc\.unk/(.*?)\)')
72
+ match = pattern.search(markdown)
73
+ if match:
74
+ return match.group(1)
75
+ return None
76
+
77
+ async def fetch_get_url_from_storage(storage_url: str) -> Optional[str]:
78
+ """
79
+ 从 storage URL 获取 JSON 并提取 getUrl
80
+ """
81
+ async with httpx.AsyncClient() as client:
82
+ try:
83
+ response = await client.get(storage_url)
84
+ if response.status_code != 200:
85
+ print(f"获取 storage URL 失败,状态码: {response.status_code}")
86
+ return None
87
+ json_response = response.json()
88
+ return json_response.get("getUrl")
89
+ except Exception as e:
90
+ print(f"Error fetching getUrl from storage: {e}")
91
+ return None
92
+
93
+ async def download_image(image_url: str) -> Optional[bytes]:
94
+ """
95
+ 下载图像
96
+ """
97
+ async with httpx.AsyncClient() as client:
98
+ try:
99
+ response = await client.get(image_url)
100
+ if response.status_code == 200:
101
+ return response.content
102
+ else:
103
+ print(f"下载图像失败,状态码: {response.status_code}")
104
+ return None
105
+ except Exception as e:
106
+ print(f"Error downloading image: {e}")
107
+ return None
108
+
109
+ def save_base64_image(base64_str: str, images_dir: str = "images") -> str:
110
+ """
111
+ 将Base64编码的图片保存到images目录,返回文件名
112
+ """
113
+ if not os.path.exists(images_dir):
114
+ os.makedirs(images_dir)
115
+ image_data = base64.b64decode(base64_str)
116
+ filename = f"{uuid.uuid4()}.png" # 默认保存为png格式
117
+ file_path = os.path.join(images_dir, filename)
118
+ with open(file_path, "wb") as f:
119
+ f.write(image_data)
120
+ return filename
121
+
122
+ def is_base64_image(url: str) -> bool:
123
+ """
124
+ 判断URL是否为Base64编码的图片
125
+ """
126
+ return url.startswith("data:image/")
127
+
128
+ # 添加 HTTPBearer 实例
129
+ security = HTTPBearer()
130
+
131
+ # 添加 API_KEY 验证函数
132
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
133
+ api_key = os.environ.get("API_KEY")
134
+ if api_key is None:
135
+ raise HTTPException(status_code=500, detail="API_KEY not set in environment variables")
136
+ if credentials.credentials != api_key:
137
+ raise HTTPException(status_code=401, detail="Invalid API key")
138
+ return credentials.credentials
139
+
140
+ # 根路径GET请求处理
141
+ @app.get("/")
142
+ async def root():
143
+ return JSONResponse(content={
144
+ "service": "AI Chat Completion Proxy",
145
+ "usage": {
146
+ "endpoint": "/ai/v1/chat/completions",
147
+ "method": "POST",
148
+ "headers": {
149
+ "Content-Type": "application/json",
150
+ "Authorization": "Bearer YOUR_API_KEY"
151
+ },
152
+ "body": {
153
+ "model": "One of: " + ", ".join(MODELS),
154
+ "messages": [
155
+ {"role": "system", "content": "You are a helpful assistant."},
156
+ {"role": "user", "content": "Hello, who are you?"}
157
+ ],
158
+ "stream": False,
159
+ "temperature": 0.7,
160
+ "max_tokens": 8000
161
+ }
162
+ },
163
+ "availableModels": MODELS,
164
+ "endpoints": {
165
+ "/ai/v1/chat/completions": "Chat completion endpoint",
166
+ "/ai/v1/images/generations": "Image generation endpoint",
167
+ "/ai/v1/models": "List available models"
168
+ },
169
+ "note": "Replace YOUR_API_KEY with your actual API key."
170
+ })
171
+
172
+ # 返回模型列表
173
+ @app.get("/ai/v1/models")
174
+ async def list_models(api_key: str = Depends(verify_api_key)):
175
+ """返回可用模型列表。"""
176
+ models = [
177
+ {
178
+ "id": model,
179
+ "object": "model",
180
+ "created": int(time.time()),
181
+ "owned_by": "chaton",
182
+ "permission": [],
183
+ "root": model,
184
+ "parent": None,
185
+ } for model in MODELS
186
+ ]
187
+ return JSONResponse(content={
188
+ "object": "list",
189
+ "data": models
190
+ })
191
+
192
+ # 聊天完成处理
193
+ @app.post("/ai/v1/chat/completions")
194
+ async def chat_completions(request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
195
+ """
196
+ 处理聊天完成请求
197
+ """
198
+ try:
199
+ request_body = await request.json()
200
+ except json.JSONDecodeError:
201
+ raise HTTPException(status_code=400, detail="Invalid JSON")
202
+
203
+ # 打印接收到的请求
204
+ print("Received Completion JSON:", json.dumps(request_body, ensure_ascii=False))
205
+
206
+ # 处理消息内容
207
+ messages = request_body.get("messages", [])
208
+ temperature = request_body.get("temperature", 1.0)
209
+ top_p = request_body.get("top_p", 1.0)
210
+ max_tokens = request_body.get("max_tokens", 8000)
211
+ model = request_body.get("model", "gpt-4o")
212
+ is_stream = request_body.get("stream", False) # 获取 stream 字段
213
+
214
+ has_image = False
215
+ has_text = False
216
+
217
+ # 清理和提取消息内容
218
+ cleaned_messages = []
219
+ for message in messages:
220
+ content = message.get("content", "")
221
+ if isinstance(content, list):
222
+ text_parts = []
223
+ images = []
224
+ for item in content:
225
+ if "text" in item:
226
+ text_parts.append(item.get("text", ""))
227
+ elif "image_url" in item:
228
+ has_image = True
229
+ image_info = item.get("image_url", {})
230
+ url = image_info.get("url", "")
231
+ if is_base64_image(url):
232
+ # 解码并保存图片
233
+ base64_str = url.split(",")[1]
234
+ filename = save_base64_image(base64_str)
235
+ base_url = app.state.base_url
236
+ image_url = f"{base_url}/images/{filename}"
237
+ images.append({"data": image_url})
238
+ else:
239
+ images.append({"data": url})
240
+ extracted_content = " ".join(text_parts).strip()
241
+ if extracted_content:
242
+ has_text = True
243
+ message["content"] = extracted_content
244
+ if images:
245
+ message["images"] = images
246
+ cleaned_messages.append(message)
247
+ print("Extracted:", extracted_content)
248
+ else:
249
+ if images:
250
+ has_image = True
251
+ message["content"] = ""
252
+ message["images"] = images
253
+ cleaned_messages.append(message)
254
+ print("Extracted image only.")
255
+ else:
256
+ print("Deleted message with empty content.")
257
+ elif isinstance(content, str):
258
+ content_str = content.strip()
259
+ if content_str:
260
+ has_text = True
261
+ message["content"] = content_str
262
+ cleaned_messages.append(message)
263
+ print("Retained content:", content_str)
264
+ else:
265
+ print("Deleted message with empty content.")
266
+ else:
267
+ print("Deleted non-expected type of content message.")
268
+
269
+ if not cleaned_messages:
270
+ raise HTTPException(status_code=400, detail="所有消息的内容均为空。")
271
+
272
+ # 验证模型
273
+ if model not in MODELS:
274
+ model = "gpt-4o"
275
+
276
+ # 构建新的请求JSON
277
+ new_request_json = {
278
+ "function_image_gen": False,
279
+ "function_web_search": True,
280
+ "max_tokens": max_tokens,
281
+ "model": model,
282
+ "source": "chat/free",
283
+ "temperature": temperature,
284
+ "top_p": top_p,
285
+ "messages": cleaned_messages,
286
+ }
287
+
288
+ modified_request_body = json.dumps(new_request_json, ensure_ascii=False)
289
+ print("Modified Request JSON:", modified_request_body)
290
+
291
+ # 获取Bearer Token
292
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body)
293
+ if not tmp_token:
294
+ raise HTTPException(status_code=500, detail="无法生成 Bearer Token")
295
+
296
+ bearer_token, formatted_date = tmp_token
297
+
298
+ headers = {
299
+ "Date": formatted_date,
300
+ "Client-time-zone": "-05:00",
301
+ "Authorization": bearer_token,
302
+ "User-Agent": "ChatOn_Android/1.56.483",
303
+ "Accept-Language": "en-US",
304
+ "X-Cl-Options": "hb",
305
+ "Content-Type": "application/json; charset=UTF-8",
306
+ }
307
+
308
+ if is_stream:
309
+ # 流式响应处理
310
+ async def event_generator():
311
+ async with httpx.AsyncClient(timeout=None) as client_stream:
312
+ try:
313
+ async with client_stream.stream("POST", EXTERNAL_API_URL, headers=headers, content=modified_request_body) as streamed_response:
314
+ async for line in streamed_response.aiter_lines():
315
+ if line.startswith("data: "):
316
+ data = line[6:].strip()
317
+ if data == "[DONE]":
318
+ # 通知客户端流结束
319
+ yield "data: [DONE]\n\n"
320
+ break
321
+ try:
322
+ sse_json = json.loads(data)
323
+ if "choices" in sse_json:
324
+ for choice in sse_json["choices"]:
325
+ delta = choice.get("delta", {})
326
+ content = delta.get("content")
327
+ if content:
328
+ new_sse_json = {
329
+ "choices": [
330
+ {
331
+ "index": choice.get("index", 0),
332
+ "delta": {"content": content},
333
+ }
334
+ ],
335
+ "created": sse_json.get(
336
+ "created", int(datetime.now(timezone.utc).timestamp())
337
+ ),
338
+ "id": sse_json.get(
339
+ "id", str(uuid.uuid4())
340
+ ),
341
+ "model": sse_json.get("model", "gpt-4o"),
342
+ "system_fingerprint": f"fp_{uuid.uuid4().hex[:12]}",
343
+ }
344
+ new_sse_line = f"data: {json.dumps(new_sse_json, ensure_ascii=False)}\n\n"
345
+ yield new_sse_line
346
+ except json.JSONDecodeError:
347
+ print("JSON解析错误")
348
+ continue
349
+ except httpx.RequestError as exc:
350
+ print(f"外部API请求失败: {exc}")
351
+ yield f"data: {{\"error\": \"外部API请求失败: {str(exc)}\"}}\n\n"
352
+
353
+ return StreamingResponse(
354
+ event_generator(),
355
+ media_type="text/event-stream",
356
+ headers={
357
+ "Cache-Control": "no-cache",
358
+ "Connection": "keep-alive",
359
+ # CORS头已通过中间件处理,无需在这里重复添加
360
+ },
361
+ )
362
+ else:
363
+ # 非流式响应处理
364
+ async with httpx.AsyncClient(timeout=None) as client:
365
+ try:
366
+ response = await client.post(
367
+ EXTERNAL_API_URL,
368
+ headers=headers,
369
+ content=modified_request_body,
370
+ timeout=None
371
+ )
372
+
373
+ if response.status_code != 200:
374
+ raise HTTPException(
375
+ status_code=response.status_code,
376
+ detail=f"API 错误: {response.status_code}",
377
+ )
378
+
379
+ sse_lines = response.text.splitlines()
380
+ content_builder = ""
381
+ images_urls = []
382
+
383
+ for line in sse_lines:
384
+ if line.startswith("data: "):
385
+ data = line[6:].strip()
386
+ if data == "[DONE]":
387
+ break
388
+ try:
389
+ sse_json = json.loads(data)
390
+ if "choices" in sse_json:
391
+ for choice in sse_json["choices"]:
392
+ if "delta" in choice:
393
+ delta = choice["delta"]
394
+ if "content" in delta:
395
+ content_builder += delta["content"]
396
+ except json.JSONDecodeError:
397
+ print("JSON解析错误")
398
+ continue
399
+
400
+ openai_response = {
401
+ "id": f"chatcmpl-{uuid.uuid4()}",
402
+ "object": "chat.completion",
403
+ "created": int(datetime.now(timezone.utc).timestamp()),
404
+ "model": model,
405
+ "choices": [
406
+ {
407
+ "index": 0,
408
+ "message": {
409
+ "role": "assistant",
410
+ "content": content_builder,
411
+ },
412
+ "finish_reason": "stop",
413
+ }
414
+ ],
415
+ }
416
+
417
+ # 处理图片(如果有)
418
+ if has_image:
419
+ images = []
420
+ for message in cleaned_messages:
421
+ if "images" in message:
422
+ for img in message["images"]:
423
+ images.append({"data": img["data"]})
424
+ openai_response["choices"][0]["message"]["images"] = images
425
+
426
+ return JSONResponse(content=openai_response, status_code=200)
427
+ except httpx.RequestError as exc:
428
+ raise HTTPException(status_code=500, detail=f"请求失败: {str(exc)}")
429
+ except Exception as exc:
430
+ raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(exc)}")
431
+
432
+ # 图像生成处理
433
+ @app.post("/ai/v1/images/generations")
434
+ async def images_generations(request: Request, api_key: str = Depends(verify_api_key)):
435
+ """
436
+ 处理图像生成请求
437
+ """
438
+ try:
439
+ request_body = await request.json()
440
+ except json.JSONDecodeError:
441
+ return send_error_response("Invalid JSON", status_code=400)
442
+
443
+ print("Received Image Generations JSON:", json.dumps(request_body, ensure_ascii=False))
444
+
445
+ # 验证必需的字段
446
+ if "prompt" not in request_body:
447
+ return send_error_response("缺少必需的字段: prompt", status_code=400)
448
+
449
+ user_prompt = request_body.get("prompt", "").strip()
450
+ response_format = request_body.get("response_format", "b64_json").strip()
451
+
452
+ if not user_prompt:
453
+ return send_error_response("Prompt 不能为空。", status_code=400)
454
+
455
+ print(f"Prompt: {user_prompt}")
456
+
457
+ # 构建新的 TextToImage JSON 请求体
458
+ text_to_image_json = {
459
+ "function_image_gen": True,
460
+ "function_web_search": True,
461
+ "image_aspect_ratio": "1:1",
462
+ "image_style": "photographic", # 暂时固定 image_style
463
+ "max_tokens": 8000,
464
+ "messages": [
465
+ {
466
+ "content": "You are a helpful artist, please based on imagination draw a picture.",
467
+ "role": "system"
468
+ },
469
+ {
470
+ "content": "Draw: " + user_prompt,
471
+ "role": "user"
472
+ }
473
+ ],
474
+ "model": "gpt-4o", # 固定 model,只能gpt-4o或gpt-4o-mini
475
+ "source": "chat/pro_image" # 固定 source
476
+ }
477
+
478
+ modified_request_body = json.dumps(text_to_image_json, ensure_ascii=False)
479
+ print("Modified Request JSON:", modified_request_body)
480
+
481
+ # 获取Bearer Token
482
+ tmp_token = BearerTokenGenerator.get_bearer(modified_request_body, path="/chats/stream")
483
+ if not tmp_token:
484
+ return send_error_response("无法生成 Bearer Token", status_code=500)
485
+
486
+ bearer_token, formatted_date = tmp_token
487
+
488
+ headers = {
489
+ "Date": formatted_date,
490
+ "Client-time-zone": "-05:00",
491
+ "Authorization": bearer_token,
492
+ "User-Agent": "ChatOn_Android/1.53.502",
493
+ "Accept-Language": "en-US",
494
+ "X-Cl-Options": "hb",
495
+ "Content-Type": "application/json; charset=UTF-8",
496
+ }
497
+
498
+ async with httpx.AsyncClient(timeout=None) as client:
499
+ try:
500
+ response = await client.post(
501
+ EXTERNAL_API_URL, headers=headers, content=modified_request_body, timeout=None
502
+ )
503
+ if response.status_code != 200:
504
+ return send_error_response(f"API 错误: {response.status_code}", status_code=500)
505
+
506
+ # 初始化用于拼接 URL 的字符串
507
+ url_builder = ""
508
+
509
+ # 读取 SSE 流并拼接 URL
510
+ async for line in response.aiter_lines():
511
+ if line.startswith("data: "):
512
+ data = line[6:].strip()
513
+ if data == "[DONE]":
514
+ break
515
+ try:
516
+ sse_json = json.loads(data)
517
+ if "choices" in sse_json:
518
+ for choice in sse_json["choices"]:
519
+ delta = choice.get("delta", {})
520
+ content = delta.get("content")
521
+ if content:
522
+ url_builder += content
523
+ except json.JSONDecodeError:
524
+ print("JSON解析错误")
525
+ continue
526
+
527
+ image_markdown = url_builder
528
+ # Step 1: 检查Markdown文本是否为空
529
+ if not image_markdown:
530
+ print("无法从 SSE 流中构建图像 Markdown。")
531
+ return send_error_response("无法从 SSE 流中构建图像 Markdown。", status_code=500)
532
+
533
+ # Step 2, 3, 4, 5: 处理图像
534
+ extracted_path = extract_path_from_markdown(image_markdown)
535
+ if not extracted_path:
536
+ print("无法从 Markdown 中提取路径。")
537
+ return send_error_response("无法从 Markdown 中提取路径。", status_code=500)
538
+
539
+ print(f"提取的路径: {extracted_path}")
540
+
541
+ # Step 5: 拼接最终的存储URL
542
+ storage_url = f"https://api.chaton.ai/storage/{extracted_path}"
543
+ print(f"存储URL: {storage_url}")
544
+
545
+ # 获取最终下载URL
546
+ final_download_url = await fetch_get_url_from_storage(storage_url)
547
+ if not final_download_url:
548
+ return send_error_response("无法从 storage URL 获取最终下载链接。", status_code=500)
549
+
550
+ print(f"Final Download URL: {final_download_url}")
551
+
552
+ # 下载图像
553
+ image_bytes = await download_image(final_download_url)
554
+ if not image_bytes:
555
+ return send_error_response("无法从 URL 下载图像。", status_code=500)
556
+
557
+ # 转换为 Base64
558
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
559
+
560
+ # 将图片保存到images目录并构建可访问的URL
561
+ filename = save_base64_image(image_base64)
562
+ base_url = app.state.base_url
563
+ accessible_url = f"{base_url}/images/{filename}"
564
+
565
+ # 根据 response_format 返回相应的响应
566
+ if response_format.lower() == "b64_json":
567
+ response_json = {
568
+ "data": [
569
+ {
570
+ "b64_json": image_base64
571
+ }
572
+ ]
573
+ }
574
+ return JSONResponse(content=response_json, status_code=200)
575
+ else:
576
+ # 构建包含可访问URL的响应
577
+ response_json = {
578
+ "data": [
579
+ {
580
+ "url": accessible_url
581
+ }
582
+ ]
583
+ }
584
+ return JSONResponse(content=response_json, status_code=200)
585
+ except httpx.RequestError as exc:
586
+ print(f"请求失败: {exc}")
587
+ return send_error_response(f"请求失败: {str(exc)}", status_code=500)
588
+ except Exception as exc:
589
+ print(f"内部服务器错误: {exc}")
590
+ return send_error_response(f"内部服务器错误: {str(exc)}", status_code=500)
591
+
592
+ # 运行服务器
593
+ def main():
594
+ parser = argparse.ArgumentParser(description="启动ChatOn API服务器")
595
+ parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
596
+ parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
597
+ args = parser.parse_args()
598
+ base_url = args.base_url
599
+ port = args.port
600
+
601
+ # 检查 API_KEY 是否设置
602
+ if not os.environ.get("API_KEY"):
603
+ print("警告: API_KEY 环境变量未设置。客户端验证将无法正常工作。")
604
+
605
+ # 确保 images 目录存在
606
+ if not os.path.exists("images"):
607
+ os.makedirs("images")
608
+
609
+ # 设置 FastAPI 应用的 state
610
+ app.state.base_url = base_url
611
+
612
+ print(f"Server started on port {port} with base_url: {base_url}")
613
+
614
+ # 运行FastAPI应用
615
+ uvicorn.run(app, host="0.0.0.0", port=port)
616
+
617
+ async def get_available_port(start_port: int = INITIAL_PORT, end_port: int = 65535) -> int:
618
+ """查找可用的端口号"""
619
+ for port in range(start_port, end_port + 1):
620
+ try:
621
+ server = await asyncio.start_server(lambda r, w: None, host="0.0.0.0", port=port)
622
+ server.close()
623
+ await server.wait_closed()
624
+ return port
625
+ except OSError:
626
+ continue
627
+ raise RuntimeError(f"No available ports between {start_port} and {end_port}")
628
+
629
+ if __name__ == "__main__":
630
+ main()