File size: 4,774 Bytes
7da14b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
041fce6
7da14b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Usage Tracking Middleware

Tracks user API usage for security and monitoring purposes.
Especially important for OpenAI API calls which cost money.
"""

import logging
import time
from typing import Dict, Any, Optional
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from utils.environment import is_huggingface_space
import json
from datetime import datetime

logger = logging.getLogger(__name__)


class UsageTrackingMiddleware(BaseHTTPMiddleware):
    """
    Middleware to track user API usage, especially for OpenAI-powered endpoints.

    A request is "tracked" when its URL path starts with one of
    ``monitored_endpoints``. Tracked requests are logged on arrival and again
    on completion (status code and duration). Tracked requests whose path also
    starts with one of ``openai_endpoints`` additionally emit a structured
    ``OPENAI_USAGE`` JSON log record. Untracked requests pass straight through.
    """

    def __init__(self, app):
        """Wrap ``app`` and set up the endpoint prefix lists used for matching."""
        super().__init__(app)

        # Endpoints that use OpenAI API (and thus cost money)
        self.openai_endpoints = [
            "/api/knowledge-graphs/extract",
            "/api/knowledge-graphs/analyze",
            "/api/methods/",
            "/api/traces/analyze",
            "/api/causal/",
        ]

        # Endpoints that should be monitored for usage patterns
        self.monitored_endpoints = self.openai_endpoints + [
            "/api/traces/",
            "/api/tasks/",
            "/api/perturbation/",
        ]

    async def dispatch(self, request: Request, call_next):
        """
        Track API usage and log user activity.

        Args:
            request: The incoming request. ``request.state.user`` is read if
                present (expected to be set by an upstream auth middleware)
                and is treated as a mapping with optional ``username`` and
                ``auth_method`` keys.
            call_next: Callable that forwards the request to the rest of the
                application and returns its response.

        Returns:
            The response produced by ``call_next``, unmodified.

        Raises:
            Exception: Anything raised by ``call_next`` is re-raised unchanged
                after a failure line has been logged for tracked paths.
        """
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system wall clock is adjusted.
        start_time = time.perf_counter()

        path = request.url.path

        # str.startswith accepts a tuple of prefixes; an empty tuple yields
        # False, matching any() over an empty list.
        should_track = path.startswith(tuple(self.monitored_endpoints))
        if not should_track:
            # Nothing to record for unmonitored paths.
            return await call_next(request)

        is_openai_call = path.startswith(tuple(self.openai_endpoints))

        # Get user info from request state (set by auth middleware)
        user = getattr(request.state, "user", None)
        user_id = user.get("username", "anonymous") if user else "anonymous"
        user_auth_method = user.get("auth_method", "none") if user else "none"

        client_ip = request.client.host if request.client else "unknown"
        logger.info(
            f"📊 API Usage: {user_id} ({user_auth_method}) -> "
            f"{request.method} {path} from {client_ip} "
            f"{'💰 [OpenAI]' if is_openai_call else ''}"
        )

        # Process the request. A raising handler must still leave a trace in
        # the usage log, otherwise the "API Usage" line above has no matching
        # completion entry.
        try:
            response = await call_next(request)
        except Exception as exc:
            duration = time.perf_counter() - start_time
            logger.error(
                f"❌ API Failed: {user_id} -> "
                f"{request.method} {path} "
                f"[{type(exc).__name__}] in {duration:.2f}s"
            )
            raise

        duration = time.perf_counter() - start_time

        succeeded = response.status_code < 400
        status_emoji = "✅" if succeeded else "❌"
        cost_warning = " 💸 COST INCURRED" if is_openai_call and succeeded else ""

        logger.info(
            f"{status_emoji} API Complete: {user_id} -> "
            f"{request.method} {path} "
            f"[{response.status_code}] in {duration:.2f}s{cost_warning}"
        )

        # Log detailed usage for OpenAI calls
        if is_openai_call:
            self._log_openai_usage(user_id, user_auth_method, request, response, duration)

        return response

    def _log_openai_usage(
        self,
        user_id: str,
        auth_method: str,
        request: Request,
        response: Response,
        duration: float
    ):
        """
        Log detailed information about OpenAI API usage.

        Emits one machine-readable ``OPENAI_USAGE`` JSON line at WARNING level,
        followed by a human-readable summary (ERROR for status >= 400,
        INFO otherwise).

        Args:
            user_id: Identifier of the calling user.
            auth_method: How the user authenticated.
            request: The request that was served.
            response: The response that was produced.
            duration: Time spent handling the request, in seconds.
        """
        usage_record = {
            # astimezone() on a naive "now" attaches the local UTC offset, so
            # records from different environments stay comparable.
            "timestamp": datetime.now().astimezone().isoformat(),
            "user_id": user_id,
            "auth_method": auth_method,
            "endpoint": request.url.path,
            "method": request.method,
            "status_code": response.status_code,
            "duration_seconds": round(duration, 2),
            "client_ip": request.client.host if request.client else "unknown",
            "user_agent": request.headers.get("User-Agent", "unknown"),
            "environment": "hf_spaces" if is_huggingface_space() else "local",
        }

        # Log as structured data for easy parsing/analysis
        logger.warning(
            f"💰 OPENAI_USAGE: {json.dumps(usage_record, separators=(',', ':'))}"
        )

        # Also log a human-readable summary
        if response.status_code >= 400:
            logger.error(
                f"🚨 OpenAI API Error: User {user_id} got {response.status_code} "
                f"on {request.url.path} - potential abuse or misconfiguration"
            )
        else:
            logger.info(
                f"💰 OpenAI API Success: User {user_id} used {request.url.path} "
                f"({duration:.2f}s) - track costs and usage patterns"
            )