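# NOTE(review): the first lines of this capture ("Spaces:", "Sleeping", "Sleeping")
# appear to be status-banner text from the hosting web page, not Python source.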
import logging
import os
import sqlite3
import time
from contextlib import contextmanager
from datetime import datetime, timedelta

from fastapi import Request, Response
# Location of the SQLite monitoring database: "monitor.db" inside $TMPDIR,
# falling back to /tmp when that variable is unset.
DB_PATH = os.path.join(os.environ.get("TMPDIR", "/tmp"), "monitor.db")
def init_db():
    """Create the monitoring table ``api_calls`` in ``DB_PATH`` if it is missing.

    Safe to call repeatedly: the DDL uses ``CREATE TABLE IF NOT EXISTS``.
    The ``timestamp`` column defaults to SQLite's ``CURRENT_TIMESTAMP``.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS api_calls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                endpoint TEXT NOT NULL,
                ip_address TEXT NOT NULL,
                user_agent TEXT,
                status_code INTEGER,
                response_time REAL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.commit()
    finally:
        # Close even when the DDL fails (e.g. the file is not a valid database),
        # so the connection is not leaked.
        conn.close()
@contextmanager
def get_db():
    """Open a SQLite connection to ``DB_PATH`` for use in a ``with`` statement.

    Usage: ``with get_db() as conn: ...``. The connection is closed when the
    ``with`` block exits, including when it raises. Nothing here commits or
    rolls back; callers must ``conn.commit()`` any writes they want to keep.

    Yields:
        sqlite3.Connection: an open connection to ``DB_PATH``.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        yield conn
    finally:
        conn.close()
def log_api_call(endpoint: str, ip: str, user_agent: str, status_code: int, response_time: float):
    """Insert one request record into ``api_calls`` and commit it.

    Args:
        endpoint: Request path.
        ip: Client address, stored in ``ip_address``.
        user_agent: Value of the User-Agent header.
        status_code: HTTP status of the response.
        response_time: Handling time as measured by the caller.

    The row's ``timestamp`` is left to the column default.
    """
    values = (endpoint, ip, user_agent, status_code, response_time)
    with get_db() as db:
        db.execute("""
            INSERT INTO api_calls (endpoint, ip_address, user_agent, status_code, response_time)
            VALUES (?, ?, ?, ?, ?)
        """, values)
        db.commit()
def get_stats():
    """Summarize the rows recorded in ``api_calls``.

    Returns:
        dict with keys:
          - ``total_calls``: number of rows ever recorded.
          - ``calls_24h``: rows whose ``timestamp`` is within the last 24 hours.
          - ``endpoint_stats``: list of ``(endpoint, count, avg_response_time)``
            tuples for the last 24 hours, highest count first.
          - ``error_count``: last-24h rows with ``status_code >= 400``.
          - ``error_rate``: ``error_count`` as a percentage of ``calls_24h``,
            rounded to 2 decimals (0.0 when there were no calls).
    """
    with get_db() as conn:
        cursor = conn.cursor()

        # ``timestamp`` is written by SQLite's CURRENT_TIMESTAMP, i.e. UTC text
        # "YYYY-MM-DD HH:MM:SS". Build the cutoff in SQLite as well, so both sides
        # of the comparison share the same clock, timezone and text format.
        # A Python datetime.now() would be local time and would rely on the
        # sqlite3 default datetime adapter (deprecated since Python 3.12).
        # Fetch it once so every windowed query uses the same window.
        cursor.execute("SELECT datetime('now', '-24 hours')")
        cutoff = cursor.fetchone()[0]

        # Total calls
        cursor.execute("SELECT COUNT(*) FROM api_calls")
        total_calls = cursor.fetchone()[0]

        # Calls within the last 24 hours
        cursor.execute("SELECT COUNT(*) FROM api_calls WHERE timestamp > ?", (cutoff,))
        calls_24h = cursor.fetchone()[0]

        # Per-endpoint statistics
        cursor.execute("""
            SELECT endpoint, COUNT(*) as count, AVG(response_time) as avg_time
            FROM api_calls
            WHERE timestamp > ?
            GROUP BY endpoint
            ORDER BY count DESC
        """, (cutoff,))
        endpoint_stats = cursor.fetchall()

        # Error statistics
        cursor.execute("""
            SELECT COUNT(*) FROM api_calls
            WHERE status_code >= 400 AND timestamp > ?
        """, (cutoff,))
        error_count = cursor.fetchone()[0]

    return {
        "total_calls": total_calls,
        "calls_24h": calls_24h,
        "endpoint_stats": endpoint_stats,
        "error_count": error_count,
        # max(..., 1) avoids division by zero when there were no calls.
        "error_rate": round(error_count / max(calls_24h, 1) * 100, 2),
    }
def setup_monitoring(app):
    """Create the monitoring table and register a request-logging middleware on *app*.

    Each request that produces a response is recorded via ``log_api_call``:
    path, client IP, User-Agent, status code and elapsed seconds. The elapsed
    time is also returned to the client in an ``X-Process-Time`` response header.

    Args:
        app: The FastAPI application to instrument.
    """
    init_db()

    # Register as an HTTP middleware; without this the function is defined
    # but never invoked, so no request is monitored.
    @app.middleware("http")
    async def monitor_middleware(request: Request, call_next):
        # perf_counter is monotonic, so wall-clock adjustments cannot produce
        # negative or skewed durations (time.time() can).
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time

        # request.client can be None when the server does not report a peer
        # address. ip_address is NOT NULL, so fall back to a placeholder.
        client_ip = request.client.host if request.client else "unknown"

        # Monitoring is best-effort: a failed insert (e.g. "database is locked")
        # must not turn an otherwise successful response into a 500.
        # NOTE(review): log_api_call does blocking sqlite I/O on the event loop;
        # consider offloading it to a thread if request volume is high.
        try:
            log_api_call(
                endpoint=request.url.path,
                ip=client_ip,
                user_agent=request.headers.get("user-agent", ""),
                status_code=response.status_code,
                response_time=process_time,
            )
        except sqlite3.Error:
            logging.getLogger(__name__).exception(
                "Failed to record API call for %s", request.url.path
            )

        response.headers["X-Process-Time"] = str(process_time)
        return response