import logging
import os
import sqlite3
import time
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone

from fastapi import Request, Response
from fastapi.concurrency import run_in_threadpool

logger = logging.getLogger(__name__)

# SQLite file backing the monitor; placed in $TMPDIR, falling back to /tmp.
DB_PATH = os.path.join(os.getenv("TMPDIR", "/tmp"), "monitor.db")

# Text format SQLite uses for CURRENT_TIMESTAMP, which is always UTC. Cutoff
# values compared against the ``timestamp`` column must use this exact format
# and timezone, otherwise the string comparison selects the wrong rows.
_SQLITE_TS_FORMAT = "%Y-%m-%d %H:%M:%S"


@contextmanager
def get_db():
    """Yield a connection to the monitoring database, closing it on exit.

    The connection is not committed automatically; callers that write must
    call ``conn.commit()`` themselves.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        yield conn
    finally:
        conn.close()


def init_db():
    """Create the ``api_calls`` table and its timestamp index if missing."""
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS api_calls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                endpoint TEXT NOT NULL,
                ip_address TEXT NOT NULL,
                user_agent TEXT,
                status_code INTEGER,
                response_time REAL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Every statistics query filters on ``timestamp``; index it so those
        # queries do not scan the whole table as it grows.
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_api_calls_timestamp "
            "ON api_calls (timestamp)"
        )
        conn.commit()


def log_api_call(endpoint: str, ip: str, user_agent: str,
                 status_code: int, response_time: float):
    """Insert one API-call record.

    Args:
        endpoint: Request path.
        ip: Client address (must not be None; the column is NOT NULL).
        user_agent: Value of the User-Agent header.
        status_code: HTTP status code of the response.
        response_time: Handling time in seconds.

    The ``timestamp`` column is filled by SQLite's CURRENT_TIMESTAMP (UTC).
    """
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO api_calls (endpoint, ip_address, user_agent, status_code, response_time)
            VALUES (?, ?, ?, ?, ?)
        """, (endpoint, ip, user_agent, status_code, response_time))
        conn.commit()


def _utc_cutoff(hours: float) -> str:
    """Return the UTC time ``hours`` ago, formatted like SQLite CURRENT_TIMESTAMP."""
    cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
    return cutoff.strftime(_SQLITE_TS_FORMAT)


def get_stats(window_hours: float = 24):
    """Return aggregate statistics about recorded API calls.

    Args:
        window_hours: Size of the recent window, in hours, used for every
            statistic except ``total_calls``. Defaults to 24.

    Returns:
        A dict with:
          - ``total_calls``: number of calls ever recorded.
          - ``calls_24h``: number of calls within the window (the key name is
            kept for compatibility even when ``window_hours`` != 24).
          - ``endpoint_stats``: list of ``(endpoint, count, avg_time)`` rows
            for the window, ordered by count descending.
          - ``error_count``: calls in the window with status_code >= 400.
          - ``error_rate``: error_count as a percentage of ``calls_24h``,
            rounded to 2 decimals (0.0 when there were no calls).
    """
    # One cutoff, in the same UTC text format as the stored timestamps, shared
    # by all windowed queries so they agree on the window boundary.
    cutoff = _utc_cutoff(window_hours)

    with get_db() as conn:
        cursor = conn.cursor()

        # Total number of calls.
        cursor.execute("SELECT COUNT(*) FROM api_calls")
        total_calls = cursor.fetchone()[0]

        # Calls within the window.
        cursor.execute("SELECT COUNT(*) FROM api_calls WHERE timestamp > ?", (cutoff,))
        calls_24h = cursor.fetchone()[0]

        # Per-endpoint statistics within the window.
        cursor.execute("""
            SELECT endpoint, COUNT(*) as count, AVG(response_time) as avg_time
            FROM api_calls
            WHERE timestamp > ?
            GROUP BY endpoint
            ORDER BY count DESC
        """, (cutoff,))
        endpoint_stats = cursor.fetchall()

        # Error responses (status >= 400) within the window.
        cursor.execute("""
            SELECT COUNT(*) FROM api_calls
            WHERE status_code >= 400 AND timestamp > ?
        """, (cutoff,))
        error_count = cursor.fetchone()[0]

    return {
        "total_calls": total_calls,
        "calls_24h": calls_24h,
        "endpoint_stats": endpoint_stats,
        "error_count": error_count,
        "error_rate": round(error_count / max(calls_24h, 1) * 100, 2)
    }


def setup_monitoring(app):
    """Initialise the database and register the monitoring middleware on *app*.

    The middleware records every HTTP request via ``log_api_call`` and adds an
    ``X-Process-Time`` header (seconds, as a string) to the response.
    """
    init_db()

    @app.middleware("http")
    async def monitor_middleware(request: Request, call_next):
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time

        # request.client is None when the server cannot report a peer address;
        # ip_address is NOT NULL, so store a placeholder instead of crashing.
        client_ip = request.client.host if request.client else "unknown"

        try:
            # sqlite3 is blocking; run the insert in the threadpool so the
            # event loop is not stalled on disk I/O for every request.
            await run_in_threadpool(
                log_api_call,
                endpoint=request.url.path,
                ip=client_ip,
                user_agent=request.headers.get("user-agent", ""),
                status_code=response.status_code,
                response_time=process_time,
            )
        except Exception:
            # Monitoring is best-effort: failing to record a call (e.g. a
            # locked database) must not turn a good response into a 500.
            logger.exception("Failed to record API call for %s", request.url.path)

        response.headers["X-Process-Time"] = str(process_time)
        return response