AgentGraph / backend /middleware /usage_tracker.py
wu981526092's picture
add
041fce6
"""
Usage Tracking Middleware
Tracks user API usage for security and monitoring purposes.
Especially important for OpenAI API calls which cost money.
"""
import logging
import time
from typing import Dict, Any, Optional
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from utils.environment import is_huggingface_space
import json
from datetime import datetime
logger = logging.getLogger(__name__)
class UsageTrackingMiddleware(BaseHTTPMiddleware):
    """Log API usage per user, with extra detail for OpenAI-backed endpoints.

    Requests whose path starts with one of ``monitored_endpoints`` get a
    "request started" and a "request completed" log line. Requests whose path
    also starts with one of ``openai_endpoints`` additionally produce a
    structured ``OPENAI_USAGE`` JSON record, because those calls cost money.
    Requests to any other path pass through without being logged.
    """

    def __init__(self, app):
        """Wrap *app* and register the tracked path prefixes.

        Args:
            app: The ASGI application to wrap; forwarded to the base class.
        """
        super().__init__(app)

        # Path prefixes of endpoints that call the OpenAI API (and thus cost
        # money).
        self.openai_endpoints = [
            "/api/knowledge-graphs/extract",
            "/api/knowledge-graphs/analyze",
            "/api/methods/",
            "/api/traces/analyze",
            "/api/causal/",
        ]

        # Path prefixes monitored for usage patterns: every OpenAI endpoint
        # plus a few endpoints that are tracked but not billed.
        self.monitored_endpoints = self.openai_endpoints + [
            "/api/traces/",
            "/api/tasks/",
            "/api/perturbation/",
        ]

    async def dispatch(self, request: Request, call_next):
        """Forward the request and log usage for monitored paths.

        Args:
            request: The incoming request. ``request.state.user`` is read if
                present (presumably set by an upstream auth middleware) and
                is expected to be a mapping with optional ``username`` and
                ``auth_method`` keys.
            call_next: Callable that passes the request down the stack and
                returns the response.

        Returns:
            The response produced by ``call_next``, unmodified.

        Raises:
            Exception: Whatever ``call_next`` raises is logged (for monitored
                paths) and then re-raised unchanged.
        """
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted mid-request.
        started = time.perf_counter()

        user = getattr(request.state, "user", None)
        user_id = user.get("username", "anonymous") if user else "anonymous"
        user_auth_method = user.get("auth_method", "none") if user else "none"

        path = request.url.path
        # str.startswith accepts a tuple of prefixes; an empty tuple is False.
        should_track = path.startswith(tuple(self.monitored_endpoints))
        is_openai_call = path.startswith(tuple(self.openai_endpoints))

        if should_track:
            client_ip = request.client.host if request.client else "unknown"
            logger.info(
                "📊 API Usage: %s (%s) -> %s %s from %s %s",
                user_id,
                user_auth_method,
                request.method,
                path,
                client_ip,
                "💰 [OpenAI]" if is_openai_call else "",
            )

        try:
            response = await call_next(request)
        except Exception:
            # Without this, a request that crashes downstream would leave
            # only the "API Usage" line and no record of how it ended.
            if should_track:
                logger.exception(
                    "❌ API Failed: %s -> %s %s raised after %.2fs",
                    user_id,
                    request.method,
                    path,
                    time.perf_counter() - started,
                )
            raise

        duration = time.perf_counter() - started

        if should_track:
            succeeded = response.status_code < 400
            status_emoji = "✅" if succeeded else "❌"
            # Only successful OpenAI calls are flagged as having incurred cost.
            cost_warning = (
                " 💸 COST INCURRED" if is_openai_call and succeeded else ""
            )
            logger.info(
                "%s API Complete: %s -> %s %s [%s] in %.2fs%s",
                status_emoji,
                user_id,
                request.method,
                path,
                response.status_code,
                duration,
                cost_warning,
            )

            if is_openai_call:
                self._log_openai_usage(
                    user_id, user_auth_method, request, response, duration
                )

        return response

    def _log_openai_usage(
        self,
        user_id: str,
        auth_method: str,
        request: Request,
        response: Response,
        duration: float
    ):
        """Emit a structured usage record for one OpenAI-backed request.

        Logs a compact JSON record prefixed with ``OPENAI_USAGE:`` at WARNING
        level, followed by a human-readable summary at ERROR level for
        status codes >= 400 and at INFO level otherwise.

        Args:
            user_id: Username of the caller, or ``"anonymous"``.
            auth_method: Authentication method of the caller, or ``"none"``.
            request: The request that was handled.
            response: The response that was produced.
            duration: Time spent handling the request, in seconds.
        """
        usage_record = {
            # astimezone() attaches the local UTC offset so records written
            # in different environments can be ordered unambiguously.
            "timestamp": datetime.now().astimezone().isoformat(),
            "user_id": user_id,
            "auth_method": auth_method,
            "endpoint": request.url.path,
            "method": request.method,
            "status_code": response.status_code,
            "duration_seconds": round(duration, 2),
            "client_ip": request.client.host if request.client else "unknown",
            "user_agent": request.headers.get("User-Agent", "unknown"),
            "environment": "hf_spaces" if is_huggingface_space() else "local",
        }

        # Compact separators keep the record on one line for easy parsing.
        logger.warning(
            "💰 OPENAI_USAGE: %s",
            json.dumps(usage_record, separators=(',', ':')),
        )

        # Also log a human-readable summary.
        if response.status_code >= 400:
            logger.error(
                "🚨 OpenAI API Error: User %s got %s "
                "on %s - potential abuse or misconfiguration",
                user_id,
                response.status_code,
                request.url.path,
            )
        else:
            logger.info(
                "💰 OpenAI API Success: User %s used %s "
                "(%.2fs) - track costs and usage patterns",
                user_id,
                request.url.path,
                duration,
            )