ZyphrZero committed on
Commit
9752c22
·
1 Parent(s): 9b0b6dd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +211 -124
main.py CHANGED
@@ -1,13 +1,13 @@
1
-
2
 
3
  import json
4
  import re
5
  import time
6
  from datetime import datetime
7
- from typing import Dict, List, Optional, Any, Union, AsyncGenerator
8
  from urllib.parse import urljoin
9
 
10
- import httpx
11
  from fastapi import FastAPI, Request, Response, HTTPException, Header
12
  from fastapi.responses import StreamingResponse, JSONResponse
13
  from pydantic import BaseModel, Field
@@ -38,6 +38,156 @@ ORIGIN_BASE = "https://chat.z.ai"
38
  ANON_TOKEN_ENABLED = True
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # 数据结构定义
42
  class Message(BaseModel):
43
  role: str
@@ -151,32 +301,31 @@ def debug_log(format_str: str, *args):
151
 
152
 
153
  # 获取匿名token
154
- async def get_anonymous_token() -> str:
155
  """获取匿名token(每次对话使用不同token,避免共享记忆)"""
156
- async with httpx.AsyncClient(timeout=10.0) as client:
157
- headers = {
158
- "User-Agent": BROWSER_UA,
159
- "Accept": "*/*",
160
- "Accept-Language": "zh-CN,zh;q=0.9",
161
- "X-FE-Version": X_FE_VERSION,
162
- "sec-ch-ua": SEC_CH_UA,
163
- "sec-ch-ua-mobile": SEC_CH_UA_MOB,
164
- "sec-ch-ua-platform": SEC_CH_UA_PLAT,
165
- "Origin": ORIGIN_BASE,
166
- "Referer": f"{ORIGIN_BASE}/",
167
- }
168
-
169
- response = await client.get(f"{ORIGIN_BASE}/api/v1/auths/", headers=headers)
170
-
171
- if response.status_code != 200:
172
- raise Exception(f"anon token status={response.status_code}")
173
-
174
- data = response.json()
175
- token = data.get("token")
176
- if not token:
177
- raise Exception("anon token empty")
178
-
179
- return token
180
 
181
 
182
  # CORS中间件
@@ -192,8 +341,7 @@ async def add_cors_headers(request: Request, call_next):
192
 
193
  # OPTIONS处理器
194
  @app.options("/")
195
- @app.options("/v1/{path:path}")
196
- async def handle_options(path: str = ""):
197
  return Response(status_code=200)
198
 
199
 
@@ -219,19 +367,19 @@ async def handle_models():
219
  ),
220
  ]
221
  )
222
- return JSONResponse(content=response.model_dump(exclude_none=True))
223
 
224
 
225
  # 聊天完成接口
226
  @app.post("/v1/chat/completions")
227
  async def handle_chat_completions(
228
  request: OpenAIRequest,
229
- authorization: Optional[str] = Header(None)
230
  ):
231
  debug_log("收到chat completions请求")
232
 
233
  # 验证API Key
234
- if not authorization or not authorization.startswith("Bearer "):
235
  debug_log("缺少或无效的Authorization头")
236
  raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
237
 
@@ -287,12 +435,9 @@ async def handle_chat_completions(
287
  auth_token = UPSTREAM_TOKEN
288
  if ANON_TOKEN_ENABLED:
289
  try:
290
- token = await get_anonymous_token()
291
- if token:
292
- auth_token = token
293
- debug_log(f"匿名token获取成功: {token[:10] if len(token) > 10 else token}...")
294
- else:
295
- debug_log("获取到的匿名token为空,使用固定token")
296
  except Exception as e:
297
  debug_log(f"匿名token获取失败,回退固定token: {e}")
298
 
@@ -307,10 +452,10 @@ async def handle_chat_completions(
307
  }
308
  )
309
  else:
310
- return await handle_non_stream_response(upstream_req, chat_id, auth_token)
311
 
312
 
313
- async def call_upstream_with_headers(upstream_req: UpstreamRequest, referer_chat_id: str, auth_token: str) -> httpx.Response:
314
  """调用上游API"""
315
  headers = {
316
  "Content-Type": "application/json",
@@ -327,14 +472,15 @@ async def call_upstream_with_headers(upstream_req: UpstreamRequest, referer_chat
327
  }
328
 
329
  debug_log(f"调用上游API: {UPSTREAM_URL}")
330
- debug_log(f"上游请求体: {upstream_req.model_dump_json(exclude_none=True)}")
331
 
332
- async with httpx.AsyncClient(timeout=60.0) as client:
333
- response = await client.post(
334
- UPSTREAM_URL,
335
- json=upstream_req.model_dump(exclude_none=True),
336
- headers=headers
337
- )
 
338
 
339
  debug_log(f"上游响应状态: {response.status_code}")
340
  return response
@@ -361,22 +507,22 @@ def transform_thinking(s: str) -> str:
361
  return s.strip()
362
 
363
 
364
- async def handle_stream_response(upstream_req: UpstreamRequest, chat_id: str, auth_token: str) -> AsyncGenerator[str, None]:
365
  """处理流式响应"""
366
  debug_log(f"开始处理流式响应 (chat_id={chat_id})")
367
 
368
  try:
369
- response = await call_upstream_with_headers(upstream_req, chat_id, auth_token)
370
  except Exception as e:
371
  debug_log(f"调用上游失败: {e}")
372
- yield f"data: {{\"error\": \"Failed to call upstream\", \"type\": \"server_error\"}}\n\n"
373
  return
374
 
375
  if response.status_code != 200:
376
  debug_log(f"上游返回错误状态: {response.status_code}")
377
  if DEBUG_MODE:
378
  debug_log(f"上游错误响应: {response.text}")
379
- yield f"data: {{\"error\": \"Upstream error\", \"type\": \"upstream_error\"}}\n\n"
380
  return
381
 
382
  # 发送第一个chunk(role)
@@ -392,29 +538,13 @@ async def handle_stream_response(upstream_req: UpstreamRequest, chat_id: str, au
392
  )
393
  yield f"data: {first_chunk.model_dump_json()}\n\n"
394
 
395
- # 读取上游SSE
396
  debug_log("开始读取上游SSE流")
397
- line_count = 0
398
  sent_initial_answer = False
399
 
400
- try:
401
- async for line in response.aiter_lines():
402
- line_count += 1
403
-
404
- if not line.startswith("data: "):
405
- continue
406
-
407
- data_str = line[6:] # 去掉 "data: "
408
- if not data_str:
409
- continue
410
-
411
- debug_log(f"收到SSE数据 (第{line_count}行): {data_str}")
412
-
413
- try:
414
- upstream_data = UpstreamData.model_validate_json(data_str)
415
- except Exception as e:
416
- debug_log(f"SSE数据解析失败: {e}")
417
- continue
418
 
419
  # 错误检测
420
  if (upstream_data.error or
@@ -453,8 +583,7 @@ async def handle_stream_response(upstream_req: UpstreamRequest, chat_id: str, au
453
 
454
  out = upstream_data.data.edit_content
455
  if out:
456
- # 使用正则表达式分割,支持多行
457
- parts = re.split(r'</details>', out)
458
  if len(parts) > 1:
459
  content = parts[1]
460
  if content:
@@ -526,19 +655,16 @@ async def handle_stream_response(upstream_req: UpstreamRequest, chat_id: str, au
526
  )
527
  yield f"data: {end_chunk.model_dump_json()}\n\n"
528
  yield "data: [DONE]\n\n"
529
- debug_log(f"流式响应完成,共处理{line_count}行")
530
  break
531
- except Exception as e:
532
- debug_log(f"读取SSE流时发生错误: {e}")
533
- yield f"data: {{\"error\": \"Stream reading error\", \"type\": \"stream_error\"}}\n\n"
534
 
535
 
536
- async def handle_non_stream_response(upstream_req: UpstreamRequest, chat_id: str, auth_token: str) -> JSONResponse:
537
  """处理非流式响应"""
538
  debug_log(f"开始处理非流式响应 (chat_id={chat_id})")
539
 
540
  try:
541
- response = await call_upstream_with_headers(upstream_req, chat_id, auth_token)
542
  except Exception as e:
543
  debug_log(f"调用上游失败: {e}")
544
  raise HTTPException(status_code=502, detail="Failed to call upstream")
@@ -553,19 +679,9 @@ async def handle_non_stream_response(upstream_req: UpstreamRequest, chat_id: str
553
  full_content = []
554
  debug_log("开始收集完整响应内容")
555
 
556
- try:
557
- async for line in response.aiter_lines():
558
- if not line.startswith("data: "):
559
- continue
560
-
561
- data_str = line[6:]
562
- if not data_str:
563
- continue
564
-
565
- try:
566
- upstream_data = UpstreamData.model_validate_json(data_str)
567
- except Exception:
568
- continue
569
 
570
  if upstream_data.data.delta_content:
571
  out = upstream_data.data.delta_content
@@ -579,9 +695,6 @@ async def handle_non_stream_response(upstream_req: UpstreamRequest, chat_id: str
579
  if upstream_data.data.done or upstream_data.data.phase == "done":
580
  debug_log("检测到完成信号,停止收集")
581
  break
582
- except Exception as e:
583
- debug_log(f"读取响应流时发生错误: {e}")
584
- raise HTTPException(status_code=502, detail="Failed to read upstream response")
585
 
586
  final_content = "".join(full_content)
587
  debug_log(f"内容收集完成,最终长度: {len(final_content)}")
@@ -604,17 +717,7 @@ async def handle_non_stream_response(upstream_req: UpstreamRequest, chat_id: str
604
  )
605
 
606
  debug_log("非流式响应发送完成")
607
- # 添加CORS头
608
- headers = {
609
- "Access-Control-Allow-Origin": "*",
610
- "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
611
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
612
- "Access-Control-Allow-Credentials": "true"
613
- }
614
- return JSONResponse(
615
- content=response_data.model_dump(exclude_none=True),
616
- headers=headers
617
- )
618
 
619
 
620
  # 根路径处理器
@@ -623,22 +726,6 @@ async def root():
623
  return {"message": "OpenAI Compatible API Server"}
624
 
625
 
626
- # 根路径处理器
627
- @app.get("/")
628
- async def root():
629
- return {"message": "OpenAI Compatible API Server"}
630
-
631
-
632
- # 健康检查接口
633
- @app.get("/health")
634
- async def health_check():
635
- return {"status": "ok", "timestamp": int(time.time())}
636
-
637
-
638
  if __name__ == "__main__":
639
  import uvicorn
640
- print(f"OpenAI兼容API服务器启动在端口 {PORT}")
641
- print(f"模型: {DEFAULT_MODEL_NAME}")
642
- print(f"上游: {UPSTREAM_URL}")
643
- print(f"Debug模式: {DEBUG_MODE}")
644
- uvicorn.run("main:app", host="0.0.0.0", port=PORT, reload=DEBUG_MODE)
 
1
+ # -*- coding: utf-8 -*-
2
 
3
  import json
4
  import re
5
  import time
6
  from datetime import datetime
7
+ from typing import Dict, List, Optional, Any, Union, Generator
8
  from urllib.parse import urljoin
9
 
10
+ import requests
11
  from fastapi import FastAPI, Request, Response, HTTPException, Header
12
  from fastapi.responses import StreamingResponse, JSONResponse
13
  from pydantic import BaseModel, Field
 
38
  ANON_TOKEN_ENABLED = True
39
 
40
 
41
+ # SSE 解析生成器
42
class SSEParser:
    """Line-oriented parser for a Server-Sent Events (SSE) response stream.

    Wraps a response object that exposes ``iter_lines()`` (for example a
    ``requests.Response`` obtained with ``stream=True``) and yields one dict
    per recognised SSE field line (``data``, ``event``, ``id``, ``retry``).
    The parser can be used as a context manager; the wrapped response is
    closed on exit.
    """

    def __init__(self, response, debug_mode=False):
        """Store the response to parse.

        Args:
            response: Object providing ``iter_lines()``. Its ``close()``
                method is called by :meth:`close` when it exists.
            debug_mode: When true, diagnostic messages are printed.
        """
        self.response = response
        self.debug_mode = debug_mode
        # Not read anywhere in this class; kept so existing attribute
        # access from outside keeps working.
        self.buffer = ""
        self.line_count = 0

    def debug_log(self, format_str: str, *args):
        """Print a diagnostic message when debug mode is enabled.

        ``format_str`` is %-interpolated only when ``args`` are supplied.
        A message passed without arguments is printed verbatim, so payload
        text containing ``%`` (e.g. ``"100%"``) can no longer raise while
        logging and abort the stream.
        """
        if not self.debug_mode:
            return
        message = format_str % args if args else format_str
        print(f"[SSE_PARSER] {message}")

    def iter_events(self):
        """Yield one dict per recognised SSE field line.

        Empty lines, comment lines (starting with ``:``), lines without a
        colon, undecodable byte lines, unknown fields and ``retry`` lines
        with a non-integer value produce no event.

        Yields:
            dict: One of
                ``{'type': 'data', 'data': <parsed JSON>, 'raw': <str>}``,
                ``{'type': 'data', 'data': <str>, 'raw': <str>,
                'is_json': False}`` when the payload is not valid JSON,
                ``{'type': 'event', 'event': <str>}``,
                ``{'type': 'id', 'id': <str>}`` or
                ``{'type': 'retry', 'retry': <int>}``.
        """
        self.debug_log("开始解析 SSE 流")

        for line in self.response.iter_lines():
            self.line_count += 1

            # Blank lines carry no field.
            if not line:
                continue

            # Byte lines are decoded as UTF-8; undecodable lines are skipped.
            if isinstance(line, bytes):
                try:
                    line = line.decode('utf-8')
                except UnicodeDecodeError:
                    self.debug_log("第%d行解码失败,跳过", self.line_count)
                    continue

            # Comment lines start with a colon; lines without any colon
            # have no "field: value" shape and are ignored.
            if line.startswith(':') or ':' not in line:
                continue

            field, value = line.split(':', 1)
            field = field.strip()
            value = value.lstrip()  # drop whitespace following the colon

            if field == 'data':
                self.debug_log("收到数据 (第%d行): %s", self.line_count, value)

                # Only the JSON decode is guarded; the yields stay outside
                # the try so nothing else can be caught by mistake.
                try:
                    data = json.loads(value)
                except json.JSONDecodeError:
                    # Not JSON: hand the payload back as plain text.
                    yield {
                        'type': 'data',
                        'data': value,
                        'raw': value,
                        'is_json': False,
                    }
                else:
                    yield {
                        'type': 'data',
                        'data': data,
                        'raw': value,
                    }

            elif field == 'event':
                yield {'type': 'event', 'event': value}

            elif field == 'id':
                yield {'type': 'id', 'id': value}

            elif field == 'retry':
                try:
                    retry = int(value)
                except ValueError:
                    self.debug_log("无效的 retry 值: %s", value)
                else:
                    yield {'type': 'retry', 'retry': retry}

    def iter_data_only(self):
        """Yield only the ``data`` events produced by :meth:`iter_events`.

        Yields:
            dict: Events whose ``'type'`` is ``'data'`` (JSON or plain text).
        """
        for event in self.iter_events():
            if event['type'] == 'data':
                yield event

    def iter_json_data(self, model_class=None):
        """Yield only ``data`` events whose payload parsed as JSON.

        Args:
            model_class: Optional class exposing
                ``model_validate_json(raw)`` (e.g. a Pydantic model). When
                given, each raw payload is validated with it and the
                resulting object replaces ``'data'`` in the yielded event.
                Payloads that fail validation are logged and skipped.

        Yields:
            dict: ``{'type': 'data', 'data': ..., 'raw': <str>}``.
        """
        for event in self.iter_events():
            if event['type'] != 'data' or not event.get('is_json', True):
                continue

            if not model_class:
                yield event
                continue

            # Deliberate best-effort: any validation failure skips the
            # payload instead of aborting the whole stream.
            try:
                data = model_class.model_validate_json(event['raw'])
            except Exception as e:
                self.debug_log("数据验证失败: %s", e)
                continue

            yield {
                'type': 'data',
                'data': data,
                'raw': event['raw'],
            }

    def close(self):
        """Close the wrapped response if it provides ``close()``."""
        if hasattr(self.response, 'close'):
            self.response.close()

    def __enter__(self):
        """Return the parser itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the wrapped response when leaving the ``with`` block."""
        self.close()
+
190
+
191
  # 数据结构定义
192
  class Message(BaseModel):
193
  role: str
 
301
 
302
 
303
  # 获取匿名token
304
def get_anonymous_token() -> str:
    """Fetch an anonymous token from the upstream auth endpoint.

    A different token is used for each conversation so that conversations
    do not share memory.

    Returns:
        The value of the ``"token"`` field in the JSON reply.

    Raises:
        Exception: If the reply status is not 200, or the ``"token"`` field
            is missing or empty.
    """
    browser_headers = {
        "User-Agent": BROWSER_UA,
        "Accept": "*/*",
        "Accept-Language": "zh-CN,zh;q=0.9",
        "X-FE-Version": X_FE_VERSION,
        "sec-ch-ua": SEC_CH_UA,
        "sec-ch-ua-mobile": SEC_CH_UA_MOB,
        "sec-ch-ua-platform": SEC_CH_UA_PLAT,
        "Origin": ORIGIN_BASE,
        "Referer": f"{ORIGIN_BASE}/",
    }
    auth_url = f"{ORIGIN_BASE}/api/v1/auths/"

    reply = requests.get(auth_url, headers=browser_headers, timeout=10.0)
    if reply.status_code != 200:
        raise Exception(f"anon token status={reply.status_code}")

    anon_token = reply.json().get("token")
    if anon_token:
        return anon_token
    raise Exception("anon token empty")
 
329
 
330
 
331
  # CORS中间件
 
341
 
342
  # OPTIONS处理器
343
@app.options("/")
@app.options("/v1/{path:path}")
async def handle_options(path: str = ""):
    """Answer OPTIONS (CORS preflight) requests with an empty 200 response.

    The handler is registered for the root path and for every path under
    ``/v1/``. Without the second route, an OPTIONS request for
    ``/v1/chat/completions`` matches no OPTIONS route and never reaches
    this handler.

    Args:
        path: Remainder of the URL after ``/v1/``; unused, and empty for
            the root route.

    Returns:
        An empty ``Response`` with status code 200.
    """
    return Response(status_code=200)
346
 
347
 
 
367
  ),
368
  ]
369
  )
370
+ return response
371
 
372
 
373
  # 聊天完成接口
374
  @app.post("/v1/chat/completions")
375
  async def handle_chat_completions(
376
  request: OpenAIRequest,
377
+ authorization: str = Header(...)
378
  ):
379
  debug_log("收到chat completions请求")
380
 
381
  # 验证API Key
382
+ if not authorization.startswith("Bearer "):
383
  debug_log("缺少或无效的Authorization头")
384
  raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
385
 
 
435
  auth_token = UPSTREAM_TOKEN
436
  if ANON_TOKEN_ENABLED:
437
  try:
438
+ token = get_anonymous_token()
439
+ auth_token = token
440
+ debug_log(f"匿名token获取成功: {token[:10]}...")
 
 
 
441
  except Exception as e:
442
  debug_log(f"匿名token获取失败,回退固定token: {e}")
443
 
 
452
  }
453
  )
454
  else:
455
+ return handle_non_stream_response(upstream_req, chat_id, auth_token)
456
 
457
 
458
+ def call_upstream_with_headers(upstream_req: UpstreamRequest, referer_chat_id: str, auth_token: str) -> requests.Response:
459
  """调用上游API"""
460
  headers = {
461
  "Content-Type": "application/json",
 
472
  }
473
 
474
  debug_log(f"调用上游API: {UPSTREAM_URL}")
475
+ debug_log(f"上游请求体: {upstream_req.model_dump_json()}")
476
 
477
+ response = requests.post(
478
+ UPSTREAM_URL,
479
+ json=upstream_req.model_dump(exclude_none=True),
480
+ headers=headers,
481
+ timeout=60.0,
482
+ stream=True
483
+ )
484
 
485
  debug_log(f"上游响应状态: {response.status_code}")
486
  return response
 
507
  return s.strip()
508
 
509
 
510
+ def handle_stream_response(upstream_req: UpstreamRequest, chat_id: str, auth_token: str):
511
  """处理流式响应"""
512
  debug_log(f"开始处理流式响应 (chat_id={chat_id})")
513
 
514
  try:
515
+ response = call_upstream_with_headers(upstream_req, chat_id, auth_token)
516
  except Exception as e:
517
  debug_log(f"调用上游失败: {e}")
518
+ yield "data: {\"error\": \"Failed to call upstream\"}\n\n"
519
  return
520
 
521
  if response.status_code != 200:
522
  debug_log(f"上游返回错误状态: {response.status_code}")
523
  if DEBUG_MODE:
524
  debug_log(f"上游错误响应: {response.text}")
525
+ yield "data: {\"error\": \"Upstream error\"}\n\n"
526
  return
527
 
528
  # 发送第一个chunk(role)
 
538
  )
539
  yield f"data: {first_chunk.model_dump_json()}\n\n"
540
 
541
+ # 使用 SSE 解析器处理流
542
  debug_log("开始读取上游SSE流")
 
543
  sent_initial_answer = False
544
 
545
+ with SSEParser(response, debug_mode=DEBUG_MODE) as parser:
546
+ for event in parser.iter_json_data(UpstreamData):
547
+ upstream_data = event['data']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
 
549
  # 错误检测
550
  if (upstream_data.error or
 
583
 
584
  out = upstream_data.data.edit_content
585
  if out:
586
+ parts = out.split("</details>")
 
587
  if len(parts) > 1:
588
  content = parts[1]
589
  if content:
 
655
  )
656
  yield f"data: {end_chunk.model_dump_json()}\n\n"
657
  yield "data: [DONE]\n\n"
658
+ debug_log(f"流式响应完成")
659
  break
 
 
 
660
 
661
 
662
+ def handle_non_stream_response(upstream_req: UpstreamRequest, chat_id: str, auth_token: str) -> JSONResponse:
663
  """处理非流式响应"""
664
  debug_log(f"开始处理非流式响应 (chat_id={chat_id})")
665
 
666
  try:
667
+ response = call_upstream_with_headers(upstream_req, chat_id, auth_token)
668
  except Exception as e:
669
  debug_log(f"调用上游失败: {e}")
670
  raise HTTPException(status_code=502, detail="Failed to call upstream")
 
679
  full_content = []
680
  debug_log("开始收集完整响应内容")
681
 
682
+ with SSEParser(response, debug_mode=DEBUG_MODE) as parser:
683
+ for event in parser.iter_json_data(UpstreamData):
684
+ upstream_data = event['data']
 
 
 
 
 
 
 
 
 
 
685
 
686
  if upstream_data.data.delta_content:
687
  out = upstream_data.data.delta_content
 
695
  if upstream_data.data.done or upstream_data.data.phase == "done":
696
  debug_log("检测到完成信号,停止收集")
697
  break
 
 
 
698
 
699
  final_content = "".join(full_content)
700
  debug_log(f"内容收集完成,最终长度: {len(final_content)}")
 
717
  )
718
 
719
  debug_log("非流式响应发送完成")
720
+ return JSONResponse(content=response_data.model_dump(exclude_none=True))
 
 
 
 
 
 
 
 
 
 
721
 
722
 
723
  # 根路径处理器
 
726
  return {"message": "OpenAI Compatible API Server"}
727
 
728
 
 
 
 
 
 
 
 
 
 
 
 
 
729
if __name__ == "__main__":
    # Start the ASGI server when this module is executed as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": PORT}
    uvicorn.run(app, **server_options)