dan92 commited on
Commit
a8175d4
·
verified ·
1 Parent(s): d405779

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -0
  2. retry_middleware.py +50 -89
Dockerfile CHANGED
@@ -1,5 +1,8 @@
1
  FROM hpyp/bbapi:latest
2
 
 
 
 
3
  # 复制重试中间件文件到容器中
4
  COPY retry_middleware.py /app/retry_middleware.py
5
 
 
1
  FROM hpyp/bbapi:latest
2
 
3
+ # 安装必要的依赖
4
+ RUN pip install starlette
5
+
6
  # 复制重试中间件文件到容器中
7
  COPY retry_middleware.py /app/retry_middleware.py
8
 
retry_middleware.py CHANGED
@@ -1,112 +1,73 @@
 
1
  from fastapi import Request, Response
2
  from fastapi.responses import JSONResponse
3
  import asyncio
4
  import json
5
- from typing import Callable
6
  import logging
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
- class RetryMiddleware:
12
- def __init__(
13
- self,
14
- app,
15
- max_retries: int = 3,
16
- delay: float = 1.0
17
- ):
18
- self.app = app
19
  self.max_retries = max_retries
20
  self.delay = delay
21
 
22
- async def __call__(self, scope, receive, send):
23
- if scope["type"] != "http":
24
- return await self.app(scope, receive, send)
 
25
 
26
- async def wrapped_send(message):
27
- if message["type"] == "http.response.body":
28
- body = message.get("body", b"")
29
- if body:
30
- try:
31
- response_data = json.loads(body.decode())
32
- if isinstance(response_data, dict) and response_data.get("error"):
33
- error_msg = response_data["error"].lower()
34
- if "content is not safe" in error_msg:
35
- # 需要���试
36
- return await self.handle_retry(scope, receive, send)
37
- except json.JSONDecodeError:
38
- pass
39
- await send(message)
40
-
41
- await self.app(scope, receive, wrapped_send)
42
-
43
- async def handle_retry(self, scope, receive, send):
44
- original_receive = receive
45
  for attempt in range(self.max_retries):
46
  try:
47
- logger.info(f"正在进行第 {attempt + 1} 次尝试...")
 
 
 
 
 
 
48
 
49
- # 重新构造请求
50
- request_body = b""
51
- more_body = True
52
- while more_body:
53
- message = await original_receive()
54
- if message["type"] == "http.request":
55
- request_body += message.get("body", b"")
56
- more_body = message.get("more_body", False)
57
-
58
- async def modified_receive():
59
- if not hasattr(modified_receive, 'called'):
60
- modified_receive.called = True
61
- return {
62
- "type": "http.request",
63
- "body": request_body,
64
- "more_body": False
65
- }
66
- return {"type": "http.disconnect"}
67
-
68
- response_sent = False
69
- async def modified_send(message):
70
- nonlocal response_sent
71
- if message["type"] == "http.response.start":
72
- await send(message)
73
- elif message["type"] == "http.response.body":
74
- body = message.get("body", b"")
75
- if body:
76
- try:
77
- response_data = json.loads(body.decode())
78
- if isinstance(response_data, dict) and response_data.get("error"):
79
- error_msg = response_data["error"].lower()
80
- if "content is not safe" in error_msg:
81
- if attempt < self.max_retries - 1:
82
- logger.info(f"检测到内容安全问题,等待 {self.delay} 秒后重试...")
83
- await asyncio.sleep(self.delay)
84
- return
85
- except json.JSONDecodeError:
86
- pass
87
- response_sent = True
88
- await send(message)
89
-
90
- await self.app(scope, modified_receive, modified_send)
91
- if response_sent:
92
- break
93
 
94
  except Exception as e:
95
  logger.error(f"重试过程中发生错误: {str(e)}")
96
  if attempt == self.max_retries - 1:
97
- # 如果是最后一次尝试,发送错误响应
98
- error_response = {
99
- "error": f"在 {self.max_retries} 次尝试后仍然失败: {str(e)}"
100
- }
101
- await send({
102
- "type": "http.response.start",
103
- "status": 500,
104
- "headers": [(b"content-type", b"application/json")]
105
- })
106
- await send({
107
- "type": "http.response.body",
108
- "body": json.dumps(error_response).encode()
109
- })
110
 
111
  # 使用方法:
112
  """
 
1
+ from starlette.middleware.base import BaseHTTPMiddleware
2
  from fastapi import Request, Response
3
  from fastapi.responses import JSONResponse
4
  import asyncio
5
  import json
 
6
  import logging
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
+ class RetryMiddleware(BaseHTTPMiddleware):
12
+ def __init__(self, app, max_retries: int = 3, delay: float = 1.0):
13
+ super().__init__(app)
 
 
 
 
 
14
  self.max_retries = max_retries
15
  self.delay = delay
16
 
17
+ async def dispatch(self, request: Request, call_next):
18
+ # 只处理 /api/v1/chat/completions 路径的请求
19
+ if not request.url.path.endswith('/api/v1/chat/completions'):
20
+ return await call_next(request)
21
 
22
+ # 读取原始请求体
23
+ body = await request.body()
24
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  for attempt in range(self.max_retries):
26
  try:
27
+ # 构造新的请求
28
+ async def receive():
29
+ return {
30
+ "type": "http.request",
31
+ "body": body,
32
+ "more_body": False,
33
+ }
34
 
35
+ # ���送请求并获取响应
36
+ response = await call_next(Request(request.scope, receive))
37
+
38
+ # 读取响应内容
39
+ response_body = b""
40
+ async for chunk in response.body_iterator:
41
+ response_body += chunk
42
+
43
+ try:
44
+ response_data = json.loads(response_body)
45
+ if isinstance(response_data, dict):
46
+ error = response_data.get('error', '')
47
+ if isinstance(error, str) and 'content is not safe' in error.lower():
48
+ if attempt < self.max_retries - 1:
49
+ logger.info(f"检测到内容安全问题,等待 {self.delay} 秒后进行第 {attempt + 2} 次重试...")
50
+ await asyncio.sleep(self.delay)
51
+ continue
52
+ except json.JSONDecodeError:
53
+ pass
54
+
55
+ # 如果没有错误或是最后一次尝试,返回响应
56
+ return Response(
57
+ content=response_body,
58
+ status_code=response.status_code,
59
+ headers=dict(response.headers),
60
+ media_type=response.media_type
61
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  except Exception as e:
64
  logger.error(f"重试过程中发生错误: {str(e)}")
65
  if attempt == self.max_retries - 1:
66
+ return JSONResponse(
67
+ status_code=500,
68
+ content={"error": f"在 {self.max_retries} 次尝试后仍然失败: {str(e)}"}
69
+ )
70
+ await asyncio.sleep(self.delay)
 
 
 
 
 
 
 
 
71
 
72
  # 使用方法:
73
  """