✨ Feature: Add feature: Support for counting model usage in the stats endpoint
Browse files
main.py
CHANGED
|
@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 10 |
from fastapi import FastAPI, HTTPException, Depends, Request
|
| 11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
| 12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
|
| 13 |
|
| 14 |
from models import RequestModel, ImageGenerationRequest
|
| 15 |
from request import get_payload
|
|
@@ -70,6 +71,14 @@ from datetime import timedelta
|
|
| 70 |
import json
|
| 71 |
import aiofiles
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
class StatsMiddleware(BaseHTTPMiddleware):
|
| 74 |
def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
|
| 75 |
super().__init__(app)
|
|
@@ -78,6 +87,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 78 |
self.ip_counts = defaultdict(lambda: defaultdict(int))
|
| 79 |
self.request_arrivals = defaultdict(list)
|
| 80 |
self.channel_success_counts = defaultdict(int)
|
|
|
|
| 81 |
self.channel_failure_counts = defaultdict(int)
|
| 82 |
self.lock = asyncio.Lock()
|
| 83 |
self.exclude_paths = set(exclude_paths or [])
|
|
@@ -91,6 +101,20 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 91 |
async def dispatch(self, request: Request, call_next):
|
| 92 |
arrival_time = datetime.now()
|
| 93 |
start_time = time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
response = await call_next(request)
|
| 95 |
process_time = time() - start_time
|
| 96 |
|
|
@@ -103,6 +127,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 103 |
self.request_times[endpoint] += process_time
|
| 104 |
self.ip_counts[endpoint][client_ip] += 1
|
| 105 |
self.request_arrivals[endpoint].append(arrival_time)
|
|
|
|
| 106 |
|
| 107 |
return response
|
| 108 |
|
|
@@ -121,6 +146,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 121 |
stats = {
|
| 122 |
"request_counts": dict(self.request_counts),
|
| 123 |
"request_times": dict(self.request_times),
|
|
|
|
| 124 |
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
|
| 125 |
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
|
| 126 |
"channel_success_counts": dict(self.channel_success_counts),
|
|
@@ -553,6 +579,7 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
|
|
| 553 |
stats = {
|
| 554 |
"channel_success_percentages": middleware.calculate_success_percentages(),
|
| 555 |
"channel_failure_percentages": middleware.calculate_failure_percentages(),
|
|
|
|
| 556 |
"request_counts": dict(middleware.request_counts),
|
| 557 |
"request_times": dict(middleware.request_times),
|
| 558 |
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
|
|
|
|
| 10 |
from fastapi import FastAPI, HTTPException, Depends, Request
|
| 11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
| 12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
+
from fastapi.exceptions import RequestValidationError
|
| 14 |
|
| 15 |
from models import RequestModel, ImageGenerationRequest
|
| 16 |
from request import get_payload
|
|
|
|
| 71 |
import json
|
| 72 |
import aiofiles
|
| 73 |
|
| 74 |
+
async def parse_request_body(request: Request):
|
| 75 |
+
if request.method == "POST" and "application/json" in request.headers.get("content-type", ""):
|
| 76 |
+
try:
|
| 77 |
+
return await request.json()
|
| 78 |
+
except json.JSONDecodeError:
|
| 79 |
+
return None
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
class StatsMiddleware(BaseHTTPMiddleware):
|
| 83 |
def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
|
| 84 |
super().__init__(app)
|
|
|
|
| 87 |
self.ip_counts = defaultdict(lambda: defaultdict(int))
|
| 88 |
self.request_arrivals = defaultdict(list)
|
| 89 |
self.channel_success_counts = defaultdict(int)
|
| 90 |
+
self.model_counts = defaultdict(int)
|
| 91 |
self.channel_failure_counts = defaultdict(int)
|
| 92 |
self.lock = asyncio.Lock()
|
| 93 |
self.exclude_paths = set(exclude_paths or [])
|
|
|
|
| 101 |
async def dispatch(self, request: Request, call_next):
|
| 102 |
arrival_time = datetime.now()
|
| 103 |
start_time = time()
|
| 104 |
+
|
| 105 |
+
# 使用依赖注入获取预解析的请求体
|
| 106 |
+
request.state.parsed_body = await parse_request_body(request)
|
| 107 |
+
|
| 108 |
+
model = "unknown"
|
| 109 |
+
if request.state.parsed_body:
|
| 110 |
+
try:
|
| 111 |
+
request_model = RequestModel(**request.state.parsed_body)
|
| 112 |
+
model = request_model.model
|
| 113 |
+
except RequestValidationError:
|
| 114 |
+
pass
|
| 115 |
+
except Exception as e:
|
| 116 |
+
logger.error(f"Error processing request: {str(e)}")
|
| 117 |
+
|
| 118 |
response = await call_next(request)
|
| 119 |
process_time = time() - start_time
|
| 120 |
|
|
|
|
| 127 |
self.request_times[endpoint] += process_time
|
| 128 |
self.ip_counts[endpoint][client_ip] += 1
|
| 129 |
self.request_arrivals[endpoint].append(arrival_time)
|
| 130 |
+
self.model_counts[model] += 1
|
| 131 |
|
| 132 |
return response
|
| 133 |
|
|
|
|
| 146 |
stats = {
|
| 147 |
"request_counts": dict(self.request_counts),
|
| 148 |
"request_times": dict(self.request_times),
|
| 149 |
+
"model_counts": dict(self.model_counts),
|
| 150 |
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
|
| 151 |
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
|
| 152 |
"channel_success_counts": dict(self.channel_success_counts),
|
|
|
|
| 579 |
stats = {
|
| 580 |
"channel_success_percentages": middleware.calculate_success_percentages(),
|
| 581 |
"channel_failure_percentages": middleware.calculate_failure_percentages(),
|
| 582 |
+
"model_counts": dict(middleware.model_counts),
|
| 583 |
"request_counts": dict(middleware.request_counts),
|
| 584 |
"request_times": dict(middleware.request_times),
|
| 585 |
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
|