GAP / app /core /utils /request_helpers.py
misonL's picture
Upload 52 files
e82bac2 verified
Raw
History Blame Contribute Delete
8.07 kB
# -*- coding: utf-8 -*-
"""
请求处理相关的辅助函数。
包含获取客户端 IP、基于 IP 的速率限制(可能已废弃)以及获取时间戳的功能。
"""
import time # 导入时间模块,用于获取时间戳
import logging # 导入日志模块
from datetime import datetime, timedelta, timezone # 导入日期时间处理相关类
import pytz # 导入时区库,用于处理太平洋时区
from fastapi import Request, HTTPException # 导入 FastAPI 的请求对象和 HTTP 异常类
from collections import defaultdict # 导入默认字典,方便初始化嵌套字典
from threading import Lock # 导入线程锁,用于保护共享数据
from typing import Dict, Union, List, Tuple # 导入类型提示
# 获取名为 'my_logger' 的日志记录器实例
logger = logging.getLogger("my_logger")
# --- IP 地址和滥用防护 ---
# 用于存储每个 IP 地址的请求时间戳和每日计数 (可能已废弃或需要与 Key 级别限制协调)
# 结构: {'ip_address': {'daily_count': int, 'timestamps': List[float]}}
ip_request_data: Dict[str, Dict[str, Union[int, List[float]]]] = defaultdict(lambda: {"daily_count": 0, "timestamps": []})
# 用于保护 ip_request_data 访问的线程锁
ip_daily_counts_lock = Lock()
# 获取太平洋时区对象,用于每日计数重置
pacific_tz = pytz.timezone('America/Los_Angeles')
def get_client_ip(request: Request) -> str:
"""
从 FastAPI 请求对象中提取真实的客户端 IP 地址。
优先检查常见的代理头 ('X-Forwarded-For', 'X-Real-IP'),
如果不存在,则回退到直接连接的客户端地址 (`request.client.host`)。
Args:
request (Request): FastAPI 请求对象。
Returns:
str: 客户端 IP 地址字符串。如果无法确定,则返回 "Unknown"。
"""
# 检查 'X-Forwarded-For' 头,通常由代理服务器添加
x_forwarded_for = request.headers.get("x-forwarded-for")
if x_forwarded_for:
# 'X-Forwarded-For' 可能包含多个 IP (client, proxy1, proxy2, ...),取第一个通常是原始客户端 IP
ip = x_forwarded_for.split(",")[0].strip() # 按逗号分割并取第一个,去除空白
logger.debug(f"从 X-Forwarded-For 获取 IP: {ip}") # 记录日志
return ip
# 检查 'X-Real-IP' 头,一些反向代理(如 Nginx)会使用这个头
x_real_ip = request.headers.get("x-real-ip")
if x_real_ip:
logger.debug(f"从 X-Real-IP 获取 IP: {x_real_ip}") # 记录日志
return x_real_ip
# 如果代理头都不存在,回退到 FastAPI 提供的直接连接客户端信息
# request.client 包含 (host, port) 元组
client_host = request.client.host if request.client else None
if client_host:
logger.debug(f"从 request.client.host 获取 IP: {client_host}") # 记录日志
return client_host
else: # 如果无法获取任何 IP 信息
logger.warning("无法从请求头或 client 属性中获取客户端 IP 地址。") # 记录警告
return "Unknown" # 返回 "Unknown"
def protect_from_abuse(request: Request, max_rpm: int, max_rpd: int):
"""
(可能已废弃/需要审查) 基于 IP 地址实现简单的请求速率 (RPM) 和每日总量 (RPD) 限制。
注意:此函数使用全局字典 `ip_request_data` 存储状态,可能与基于 Key 的限制冲突或重复。
在当前系统中,速率限制主要由 Key Manager 处理,此函数可能不再需要或需要重新设计。
Args:
request (Request): FastAPI 请求对象。
max_rpm (int): 每个 IP 允许的最大每分钟请求数。
max_rpd (int): 每个 IP 允许的最大每日请求数。
Raises:
HTTPException (429 Too Many Requests): 如果检测到超过速率或每日限制。
"""
logger.warning("调用了可能已废弃的 protect_from_abuse 函数 (基于 IP 的速率限制)。") # 记录警告
global ip_request_data, ip_daily_counts_lock, pacific_tz # 声明使用全局变量
ip = get_client_ip(request) # 获取客户端 IP
if ip == "Unknown": # 如果无法获取 IP
logger.warning("无法获取客户端 IP 地址,跳过滥用检查。") # 记录警告并跳过检查
return
now = time.time() # 获取当前时间戳
today_pacific = datetime.now(pacific_tz).date() # 获取当前太平洋时区的日期
with ip_daily_counts_lock: # 获取 IP 计数锁,保证线程安全
ip_data = ip_request_data[ip] # 获取或创建该 IP 的数据字典
daily_count = ip_data.get("daily_count", 0) # 获取当日请求计数
timestamps: List[float] = ip_data.get("timestamps", []) # 获取请求时间戳列表
# --- 检查每日请求限制 (RPD) ---
# 获取列表中最早的时间戳对应的日期 (如果列表不为空)
last_request_date_pacific = datetime.fromtimestamp(timestamps[0], pacific_tz).date() if timestamps else None
# 如果上次请求不是今天 (太平洋时间),则重置每日计数
if last_request_date_pacific != today_pacific:
daily_count = 0 # 重置计数
timestamps = [] # 清空时间戳列表 (因为 RPM 检查也基于此列表)
ip_data["daily_count"] = 0 # 更新存储的计数
ip_data["timestamps"] = [] # 更新存储的时间戳列表
logger.debug(f"IP {ip} 的每日计数已重置 (新的一天)。") # 记录日志
# 检查加上当前请求是否超过每日限制
if daily_count >= max_rpd:
logger.warning(f"IP {ip} 已达到每日请求限制 ({max_rpd} RPD)。") # 记录警告
# 抛出 429 错误
raise HTTPException(status_code=429, detail=f"您已达到每日请求限制 ({max_rpd} RPD)。请明天再试。")
# --- 检查每分钟请求限制 (RPM) ---
rpm_window_seconds = 60 # 定义 RPM 的时间窗口为 60 秒
# 移除时间戳列表中所有早于 (当前时间 - 窗口时长) 的时间戳
timestamps = [ts for ts in timestamps if now - ts < rpm_window_seconds]
# 检查剩余时间戳的数量是否达到或超过 RPM 限制
if len(timestamps) >= max_rpm:
# 计算需要等待的时间:窗口时长 - (当前时间 - 最早的时间戳)
earliest_timestamp = timestamps[0] if timestamps else now # 获取窗口内最早的时间戳
wait_time = max(0.0, earliest_timestamp + rpm_window_seconds - now) # 计算剩余等待时间,确保不为负
logger.warning(f"IP {ip} 请求过于频繁,触发 RPM 限制 ({max_rpm} RPM)。需要等待 {wait_time:.2f} 秒。") # 记录警告
# 抛出 429 错误
raise HTTPException(status_code=429, detail=f"请求过于频繁。请在 {wait_time:.2f} 秒后重试。")
# --- 更新计数和时间戳 ---
# 如果检查通过,则更新计数和时间戳列表
ip_data["daily_count"] = daily_count + 1 # 每日计数加 1
timestamps.append(now) # 将当前时间戳添加到列表末尾
ip_data["timestamps"] = timestamps # 更新存储的时间戳列表
logger.debug(f"IP {ip} 请求计数更新: RPD={ip_data['daily_count']}, RPM_Window_Count={len(timestamps)}") # 记录调试日志
def get_current_timestamps() -> Tuple[float, str]: # 返回类型修改为 Tuple[float, str]
"""
获取当前的 Unix 时间戳和太平洋时区的日期字符串。
Returns:
Tuple[float, str]:
- 第一个元素:当前的 Unix 时间戳 (float)。
- 第二个元素:太平洋时区的当前日期字符串 (str, 'YYYY-MM-DD' 格式)。
"""
now_timestamp = time.time() # 获取当前 Unix 时间戳
# 获取当前太平洋时区的日期并格式化为 ISO 格式字符串 ('YYYY-MM-DD')
today_pacific = datetime.now(pacific_tz).date()
today_date_str_pt = today_pacific.isoformat()
return now_timestamp, today_date_str_pt # 返回元组