# MiroFish — backend/app/utils/retry.py (deployed to HF Space via Codex Deploy, commit ebdfd3b)
"""
API调用重试机制
用于处理LLM等外部API调用的重试逻辑
"""
import time
import random
import functools
from typing import Callable, Any, Optional, Type, Tuple
from ..utils.logger import get_logger
logger = get_logger('mirofish.retry')
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Retry decorator with exponential backoff.

    Args:
        max_retries: maximum number of retries (total attempts = max_retries + 1)
        initial_delay: initial delay in seconds
        max_delay: upper bound on the delay in seconds
        backoff_factor: multiplier applied to the delay after each failure
        jitter: whether to randomize the delay into [0.5x, 1.5x)
        exceptions: exception types that trigger a retry
        on_retry: optional callback invoked as on_retry(exception, retry_count)

    Usage:
        @retry_with_backoff(max_retries=3)
        def call_llm_api():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            last_exception = None
            delay = initial_delay
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    if attempt == max_retries:
                        # Fix: restored the missing "在" separator so the message
                        # matches the RetryableAPIClient wording instead of
                        # rendering as "函数 foo3 次重试后仍失败".
                        logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
                        raise
                    # Cap the delay, then optionally jitter it into [0.5x, 1.5x).
                    current_delay = min(delay, max_delay)
                    if jitter:
                        current_delay = current_delay * (0.5 + random.random())
                    # Fix: restored the missing "第" before the attempt number.
                    logger.warning(
                        f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
                        f"{current_delay:.1f}秒后重试..."
                    )
                    if on_retry:
                        on_retry(e, attempt + 1)
                    time.sleep(current_delay)
                    delay *= backoff_factor
            # Unreachable in practice (the loop either returns or re-raises);
            # kept as a defensive fallback.
            raise last_exception
        return wrapper
    return decorator
def retry_with_backoff_async(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Async variant of :func:`retry_with_backoff`.

    Same parameters and semantics, but the wrapped coroutine is awaited and
    ``asyncio.sleep`` is used between attempts so the event loop is not blocked.
    """
    import asyncio

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            last_exception = None
            delay = initial_delay
            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    if attempt == max_retries:
                        # Fix: restored the missing "在" separator (message was
                        # rendering as "异步函数 foo3 次重试后仍失败").
                        logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
                        raise
                    # Cap the delay, then optionally jitter it into [0.5x, 1.5x).
                    current_delay = min(delay, max_delay)
                    if jitter:
                        current_delay = current_delay * (0.5 + random.random())
                    # Fix: restored the missing "第" before the attempt number.
                    logger.warning(
                        f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
                        f"{current_delay:.1f}秒后重试..."
                    )
                    if on_retry:
                        on_retry(e, attempt + 1)
                    await asyncio.sleep(current_delay)
                    delay *= backoff_factor
            # Unreachable in practice; defensive fallback.
            raise last_exception
        return wrapper
    return decorator
class RetryableAPIClient:
    """API-client wrapper that retries failed calls with exponential backoff.

    The retry policy (attempt count, delays, backoff factor) is fixed at
    construction time; per-call behavior is controlled via ``exceptions``.
    """

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        backoff_factor: float = 2.0
    ):
        # Policy knobs consumed by call_with_retry.
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor

    def call_with_retry(
        self,
        func: Callable,
        *args,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        **kwargs
    ) -> Any:
        """Invoke ``func(*args, **kwargs)``, retrying on the given exceptions.

        Args:
            func: callable to invoke
            *args: positional arguments forwarded to ``func``
            exceptions: exception types that trigger a retry
            **kwargs: keyword arguments forwarded to ``func``

        Returns:
            Whatever ``func`` returns on its first successful attempt.
        """
        base_delay = self.initial_delay
        final_error = None
        attempt = 0
        while attempt <= self.max_retries:
            try:
                return func(*args, **kwargs)
            except exceptions as err:
                final_error = err
                if attempt >= self.max_retries:
                    logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(err)}")
                    raise
                # Cap the delay, then jitter it into [0.5x, 1.5x).
                current_delay = min(base_delay, self.max_delay) * (0.5 + random.random())
                logger.warning(
                    f"API调用第 {attempt + 1} 次尝试失败: {str(err)}, "
                    f"{current_delay:.1f}秒后重试..."
                )
                time.sleep(current_delay)
                base_delay *= self.backoff_factor
                attempt += 1
        # Unreachable in practice; defensive fallback.
        raise final_error

    def call_batch_with_retry(
        self,
        items: list,
        process_func: Callable,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        continue_on_failure: bool = True
    ) -> Tuple[list, list]:
        """Process each item with per-item retries via :meth:`call_with_retry`.

        Args:
            items: items to process
            process_func: callable applied to each item individually
            exceptions: exception types that trigger a retry
            continue_on_failure: when False, re-raise on the first item whose
                retries are exhausted instead of moving on

        Returns:
            A ``(successful results, failure records)`` pair; each failure
            record carries the item's index, the item, and the error text.
        """
        successes: list = []
        failed: list = []
        for position, entry in enumerate(items):
            try:
                successes.append(
                    self.call_with_retry(process_func, entry, exceptions=exceptions)
                )
            except Exception as err:
                logger.error(f"处理第 {position + 1} 项失败: {str(err)}")
                failed.append({
                    "index": position,
                    "item": entry,
                    "error": str(err)
                })
                if not continue_on_failure:
                    raise
        return successes, failed