File size: 3,325 Bytes
b17403a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import sqlite3
import time
from datetime import datetime, timedelta
from contextlib import contextmanager
from fastapi import Request, Response
import os

# Location of the monitoring SQLite file: "monitor.db" inside the directory named
# by $TMPDIR, or inside /tmp when that variable is not set.
# NOTE(review): a temp directory is presumably not durable storage — confirm that
# losing the statistics on temp cleanup is acceptable.
DB_PATH = os.path.join(os.environ.get("TMPDIR", "/tmp"), "monitor.db")

def init_db():
    """Create the monitoring schema in DB_PATH if it does not exist yet.

    Creates the ``api_calls`` table plus an index on its ``timestamp`` column.
    Both statements use ``IF NOT EXISTS``, so calling this repeatedly is safe.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS api_calls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                endpoint TEXT NOT NULL,
                ip_address TEXT NOT NULL,
                user_agent TEXT,
                status_code INTEGER,
                response_time REAL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Every statistics query filters on `timestamp`, and the table gains a
        # row per request, so index it to avoid full-table scans.
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_api_calls_timestamp "
            "ON api_calls (timestamp)"
        )
        conn.commit()
    finally:
        # Close even when a statement or the commit fails, so the
        # connection is never leaked.
        conn.close()

@contextmanager
def get_db():
    """Provide a fresh SQLite connection to DB_PATH for the duration of a ``with`` block.

    The connection is closed when the block exits, whether normally or via an
    exception. Nothing is committed here; callers commit their own writes.
    """
    handle = sqlite3.connect(DB_PATH)
    try:
        yield handle
    finally:
        handle.close()

def log_api_call(endpoint: str, ip: str, user_agent: str, status_code: int, response_time: float):
    """Insert one row into ``api_calls`` and commit it immediately.

    Args:
        endpoint: Endpoint identifier stored in ``api_calls.endpoint``.
        ip: Client address stored in ``api_calls.ip_address``.
        user_agent: User-Agent string stored in ``api_calls.user_agent``.
        status_code: HTTP status code of the response.
        response_time: Elapsed handling time as measured by the caller.

    The ``timestamp`` column is filled by the table's CURRENT_TIMESTAMP default.
    """
    row = (endpoint, ip, user_agent, status_code, response_time)
    with get_db() as conn:
        conn.execute("""
            INSERT INTO api_calls (endpoint, ip_address, user_agent, status_code, response_time)
            VALUES (?, ?, ?, ?, ?)
        """, row)
        conn.commit()

def get_stats():
    """Return aggregate API-call statistics from the monitoring database.

    The 24-hour window is computed by SQLite itself (``datetime('now', ...)``).
    That keeps it in UTC and in the same text format as the
    ``CURRENT_TIMESTAMP`` values stored in ``api_calls.timestamp``. A
    local-time Python ``datetime`` would shift the window by the host's UTC
    offset.

    Returns:
        dict with keys:
            total_calls: number of rows ever recorded.
            calls_24h: calls recorded in the last 24 hours.
            endpoint_stats: list of ``(endpoint, count, avg_time)`` tuples for
                the last 24 hours, most-called endpoint first.
            error_count: calls in the last 24 hours with ``status_code >= 400``.
            error_rate: ``error_count / calls_24h`` as a percentage rounded to
                2 decimals (0.0 when there were no calls).
    """
    with get_db() as conn:
        cursor = conn.cursor()

        # Compute the cutoff once so every windowed query uses the same instant.
        cursor.execute("SELECT datetime('now', '-24 hours')")
        cutoff = cursor.fetchone()[0]

        # Total number of calls.
        cursor.execute("SELECT COUNT(*) FROM api_calls")
        total_calls = cursor.fetchone()[0]

        # Calls and errors in the last 24 hours, counted in a single query so the
        # two numbers come from the same snapshot (error_count <= calls_24h).
        # SUM over zero rows is NULL, hence the COALESCE.
        cursor.execute("""
            SELECT COUNT(*), COALESCE(SUM(status_code >= 400), 0)
            FROM api_calls
            WHERE timestamp > ?
        """, (cutoff,))
        calls_24h, error_count = cursor.fetchone()

        # Per-endpoint statistics for the last 24 hours.
        cursor.execute("""
            SELECT endpoint, COUNT(*) as count, AVG(response_time) as avg_time
            FROM api_calls
            WHERE timestamp > ?
            GROUP BY endpoint
            ORDER BY count DESC
        """, (cutoff,))
        endpoint_stats = cursor.fetchall()

        return {
            "total_calls": total_calls,
            "calls_24h": calls_24h,
            "endpoint_stats": endpoint_stats,
            "error_count": error_count,
            # max(..., 1) avoids division by zero when there were no calls.
            "error_rate": round(error_count / max(calls_24h, 1) * 100, 2),
        }

def setup_monitoring(app):
    """Initialise the monitoring database and register a timing middleware on *app*.

    Every HTTP request handled by *app* is timed and recorded through
    ``log_api_call``. The elapsed seconds are returned to the client in the
    ``X-Process-Time`` response header. Requests whose handler raises are
    recorded with status 500, and the exception is re-raised unchanged.

    Args:
        app: The FastAPI application to instrument.
    """
    import logging

    from fastapi.concurrency import run_in_threadpool

    logger = logging.getLogger(__name__)
    init_db()

    async def _record(request: Request, status_code: int, elapsed: float) -> None:
        """Persist one call off the event loop; log (never raise) on DB failure."""
        # request.client is None when the ASGI server reports no peer address;
        # ip_address is NOT NULL, so store a placeholder instead of crashing.
        client_ip = request.client.host if request.client is not None else "unknown"
        try:
            # sqlite3 calls block, so run them in a worker thread rather than
            # stalling every other request on the event loop.
            await run_in_threadpool(
                log_api_call,
                endpoint=request.url.path,
                ip=client_ip,
                user_agent=request.headers.get("user-agent", ""),
                status_code=status_code,
                response_time=elapsed,
            )
        except sqlite3.Error:
            # Monitoring is best-effort: a failed write (e.g. "database is
            # locked") must not turn a successful request into a 500.
            logger.exception("Failed to record API call for %s", request.url.path)

    @app.middleware("http")
    async def monitor_middleware(request: Request, call_next):
        # perf_counter is monotonic, so the measurement is immune to wall-clock jumps.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # Unhandled handler errors propagate through this middleware;
            # record them as 500s so the error statistics include them.
            await _record(request, 500, time.perf_counter() - start_time)
            raise
        process_time = time.perf_counter() - start_time

        await _record(request, response.status_code, process_time)

        response.headers["X-Process-Time"] = str(process_time)
        return response