dan92 committed on
Commit
ba9bc4b
·
verified ·
1 Parent(s): a8175d4

Upload retry_middleware.py

Browse files
Files changed (1) hide show
  1. retry_middleware.py +74 -28
retry_middleware.py CHANGED
@@ -4,15 +4,39 @@ from fastapi.responses import JSONResponse
4
  import asyncio
5
  import json
6
  import logging
 
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
  class RetryMiddleware(BaseHTTPMiddleware):
12
- def __init__(self, app, max_retries: int = 3, delay: float = 1.0):
13
  super().__init__(app)
14
  self.max_retries = max_retries
15
- self.delay = delay
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  async def dispatch(self, request: Request, call_next):
18
  # 只处理 /api/v1/chat/completions 路径的请求
@@ -21,53 +45,75 @@ class RetryMiddleware(BaseHTTPMiddleware):
21
 
22
  # 读取原始请求体
23
  body = await request.body()
 
 
 
24
 
25
- for attempt in range(self.max_retries):
26
  try:
27
- # 构造新的请求
28
  async def receive():
29
  return {
30
  "type": "http.request",
31
  "body": body,
32
  "more_body": False,
33
  }
34
-
35
- # 发送请求并获取响应
36
  response = await call_next(Request(request.scope, receive))
37
-
38
- # 读取响应内容
39
  response_body = b""
40
  async for chunk in response.body_iterator:
41
  response_body += chunk
42
-
43
  try:
44
  response_data = json.loads(response_body)
45
- if isinstance(response_data, dict):
46
- error = response_data.get('error', '')
47
- if isinstance(error, str) and 'content is not safe' in error.lower():
48
- if attempt < self.max_retries - 1:
49
- logger.info(f"检测到内容安全问题,等待 {self.delay} 秒后进行第 {attempt + 2} 次重试...")
50
- await asyncio.sleep(self.delay)
51
- continue
 
 
 
 
 
 
 
 
 
 
 
 
52
  except json.JSONDecodeError:
53
- pass
54
-
55
- # 如果没有错误或是最后一次尝试,返回响应
56
- return Response(
57
- content=response_body,
58
- status_code=response.status_code,
59
- headers=dict(response.headers),
60
- media_type=response.media_type
61
- )
62
-
63
  except Exception as e:
64
  logger.error(f"重试过程中发生错误: {str(e)}")
65
- if attempt == self.max_retries - 1:
 
 
 
 
 
66
  return JSONResponse(
67
  status_code=500,
68
  content={"error": f"在 {self.max_retries} 次尝试后仍然失败: {str(e)}"}
69
  )
70
- await asyncio.sleep(self.delay)
 
 
 
 
 
 
71
 
72
  # 使用方法:
73
  """
 
4
  import asyncio
5
  import json
6
  import logging
7
+ import random
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
  class RetryMiddleware(BaseHTTPMiddleware):
13
def __init__(self, app, max_retries: int = 5, initial_delay: float = 0.5):
    """Wrap *app* and remember the retry settings.

    Args:
        app: The ASGI application handed to the base middleware.
        max_retries: Stored as ``self.max_retries``; the retry budget
            read by ``dispatch``.
        initial_delay: Stored as ``self.initial_delay``; the base delay
            (seconds) that ``dispatch`` scales for its backoff sleeps.
    """
    super().__init__(app)
    # The two settings are independent; they are only read later by dispatch.
    self.initial_delay = initial_delay
    self.max_retries = max_retries
17
+
18
async def should_retry_response(self, response_data):
    """Return True when a parsed response carries the unsafe-content marker.

    Two places are inspected, case-insensitively, for the substring
    ``content is not safe``:

    * the top-level ``error`` field, when it is a string;
    * ``choices[0].message.content``, when it is a string.

    Anything that does not have that shape (non-dict payload, missing or
    empty ``choices``, a ``message`` or ``content`` that is ``None`` or not
    text) is treated as "no retry needed" instead of raising, so that
    responses without textual content are not misreported as failures.

    Args:
        response_data: The JSON-decoded response body (any JSON value).

    Returns:
        bool: True if the marker was found, otherwise False.
    """
    marker = 'content is not safe'

    if not isinstance(response_data, dict):
        return False

    # Check the error message.
    error = response_data.get('error', '')
    if isinstance(error, str) and marker in error.lower():
        return True

    # Check the first choice's message content, validating each level
    # before dereferencing it.
    choices = response_data.get('choices')
    if not isinstance(choices, list) or not choices:
        return False

    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        return False

    message = first_choice.get('message')
    if not isinstance(message, dict):
        return False

    content = message.get('content')
    # A substring test over the whole text already covers every line of it,
    # because the marker itself contains no newline.
    return isinstance(content, str) and marker in content.lower()
40
 
41
  async def dispatch(self, request: Request, call_next):
42
  # 只处理 /api/v1/chat/completions 路径的请求
 
45
 
46
  # 读取原始请求体
47
  body = await request.body()
48
+ original_response = None
49
+ best_response = None
50
+ retry_count = 0
51
 
52
+ while retry_count < self.max_retries:
53
  try:
54
+ # 构造请求
55
  async def receive():
56
  return {
57
  "type": "http.request",
58
  "body": body,
59
  "more_body": False,
60
  }
61
+
62
+ # 发送请求并获取响应
63
  response = await call_next(Request(request.scope, receive))
 
 
64
  response_body = b""
65
  async for chunk in response.body_iterator:
66
  response_body += chunk
67
+
68
  try:
69
  response_data = json.loads(response_body)
70
+
71
+ # 第一次响应,保存作为原始响应
72
+ if original_response is None:
73
+ original_response = response_data
74
+
75
+ # 检查响应是否需要重试
76
+ if await self.should_retry_response(response_data):
77
+ retry_count += 1
78
+ if retry_count < self.max_retries:
79
+ # 使用指数退避和随机抖动
80
+ delay = self.initial_delay * (2 ** retry_count) * (0.5 + random.random())
81
+ logger.info(f"检测到内容安全问题,等待 {delay:.2f} 秒后进行第 {retry_count + 1} 次重试...")
82
+ await asyncio.sleep(delay)
83
+ continue
84
+ else:
85
+ # 如果响应正常,保存为最佳响应
86
+ best_response = response_data
87
+ break
88
+
89
  except json.JSONDecodeError:
90
+ # 如果响应不是JSON格式,直接返回
91
+ return Response(
92
+ content=response_body,
93
+ status_code=response.status_code,
94
+ headers=dict(response.headers),
95
+ media_type=response.media_type
96
+ )
97
+
 
 
98
  except Exception as e:
99
  logger.error(f"重试过程中发生错误: {str(e)}")
100
+ retry_count += 1
101
+ if retry_count < self.max_retries:
102
+ delay = self.initial_delay * (2 ** retry_count) * (0.5 + random.random())
103
+ await asyncio.sleep(delay)
104
+ continue
105
+ else:
106
  return JSONResponse(
107
  status_code=500,
108
  content={"error": f"在 {self.max_retries} 次尝试后仍然失败: {str(e)}"}
109
  )
110
+
111
+ # 返回最佳响应,如果没有最佳响应则返回原始响应
112
+ final_response = best_response or original_response
113
+ return JSONResponse(
114
+ content=final_response,
115
+ status_code=200
116
+ )
117
 
118
  # 使用方法:
119
  """