dan92 commited on
Commit
ab06a63
·
verified ·
1 Parent(s): a8defe4

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. retry_middleware.py +13 -38
Dockerfile CHANGED
@@ -11,7 +11,7 @@ RUN echo 'from fastapi import FastAPI\n\
11
  from retry_middleware import RetryMiddleware\n\
12
  from main import app\n\
13
  \n\
14
- app.add_middleware(RetryMiddleware, max_retries=5, initial_delay=0.5)' > /app/wrapper.py
15
 
16
  # ENV APP_SECRET=
17
 
 
11
  from retry_middleware import RetryMiddleware\n\
12
  from main import app\n\
13
  \n\
14
+ app.add_middleware(RetryMiddleware)' > /app/wrapper.py
15
 
16
  # ENV APP_SECRET=
17
 
retry_middleware.py CHANGED
@@ -1,20 +1,16 @@
1
  from starlette.middleware.base import BaseHTTPMiddleware
2
  from fastapi import Request, Response
3
  from fastapi.responses import JSONResponse
4
- import asyncio
5
  import json
6
  import logging
7
- import random
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
  class RetryMiddleware(BaseHTTPMiddleware):
13
- def __init__(self, app, max_retries: int = 5, initial_delay: float = 0.5):
14
  super().__init__(app)
15
- self.max_retries = max_retries
16
- self.initial_delay = initial_delay
17
-
18
  async def should_retry_response(self, response_data):
19
  """检查响应是否需要重试"""
20
  if isinstance(response_data, dict):
@@ -45,11 +41,9 @@ class RetryMiddleware(BaseHTTPMiddleware):
45
 
46
  # 读取原始请求体
47
  body = await request.body()
48
- original_response = None
49
- best_response = None
50
  retry_count = 0
51
 
52
- while retry_count < self.max_retries:
53
  try:
54
  # 构造请求
55
  async def receive():
@@ -68,23 +62,19 @@ class RetryMiddleware(BaseHTTPMiddleware):
68
  try:
69
  response_data = json.loads(response_body)
70
 
71
- # 第一次响应,保存作为原始响应
72
- if original_response is None:
73
- original_response = response_data
74
-
75
  # 检查响应是否需要重试
76
  if await self.should_retry_response(response_data):
77
  retry_count += 1
78
- if retry_count < self.max_retries:
79
- # 使用指数退避和随机抖动
80
- delay = self.initial_delay * (2 ** retry_count) * (0.5 + random.random())
81
- logger.info(f"检测到内容安全问题,等待 {delay:.2f} 秒后进行第 {retry_count + 1} 次重试...")
82
- await asyncio.sleep(delay)
83
- continue
84
  else:
85
- # 如果响应正常,保存为最佳响应
86
- best_response = response_data
87
- break
 
 
 
 
88
 
89
  except json.JSONDecodeError:
90
  # 如果响应不是JSON格式,直接返回
@@ -98,22 +88,7 @@ class RetryMiddleware(BaseHTTPMiddleware):
98
  except Exception as e:
99
  logger.error(f"重试过程中发生错误: {str(e)}")
100
  retry_count += 1
101
- if retry_count < self.max_retries:
102
- delay = self.initial_delay * (2 ** retry_count) * (0.5 + random.random())
103
- await asyncio.sleep(delay)
104
- continue
105
- else:
106
- return JSONResponse(
107
- status_code=500,
108
- content={"error": f"在 {self.max_retries} 次尝试后仍然失败: {str(e)}"}
109
- )
110
-
111
- # 返回最佳响应,如果没有最佳响应则返回原始响应
112
- final_response = best_response or original_response
113
- return JSONResponse(
114
- content=final_response,
115
- status_code=200
116
- )
117
 
118
  # 使用方法:
119
  """
 
1
  from starlette.middleware.base import BaseHTTPMiddleware
2
  from fastapi import Request, Response
3
  from fastapi.responses import JSONResponse
 
4
  import json
5
  import logging
 
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  class RetryMiddleware(BaseHTTPMiddleware):
11
+ def __init__(self, app):
12
  super().__init__(app)
13
+
 
 
14
  async def should_retry_response(self, response_data):
15
  """检查响应是否需要重试"""
16
  if isinstance(response_data, dict):
 
41
 
42
  # 读取原始请求体
43
  body = await request.body()
 
 
44
  retry_count = 0
45
 
46
+ while True: # 无限循环,直到获得正确响应
47
  try:
48
  # 构造请求
49
  async def receive():
 
62
  try:
63
  response_data = json.loads(response_body)
64
 
 
 
 
 
65
  # 检查响应是否需要重试
66
  if await self.should_retry_response(response_data):
67
  retry_count += 1
68
+ logger.info(f"检测到内容安全问题,立即进行第 {retry_count + 1} 次重试...")
69
+ continue
 
 
 
 
70
  else:
71
+ # 如果响应正常,直接返回
72
+ return Response(
73
+ content=response_body,
74
+ status_code=response.status_code,
75
+ headers=dict(response.headers),
76
+ media_type=response.media_type
77
+ )
78
 
79
  except json.JSONDecodeError:
80
  # 如果响应不是JSON格式,直接返回
 
88
  except Exception as e:
89
  logger.error(f"重试过程中发生错误: {str(e)}")
90
  retry_count += 1
91
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # 使用方法:
94
  """