File size: 8,859 Bytes
62be87c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05fe460
62be87c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
from flask import request, jsonify, g, session, redirect, url_for
from functools import wraps
from typing import Optional, Tuple
import time
import hashlib
from collections import defaultdict
import secrets

from ..config.auth import auth_config
from ..services.user_store import user_store, AuthResult

# Rate limiting storage: maps an identifier (the client IP in proxy_key mode,
# the user token in user_token mode) to the time.time() timestamps of its
# accepted requests, as maintained by check_rate_limit().
rate_limit_store = defaultdict(list)

def extract_token_from_request() -> Optional[str]:
    """Return the credential supplied with the current request, or None.

    Sources are checked in priority order:

    1. ``Authorization: Bearer <token>`` header
    2. ``x-api-key`` header
    3. ``key`` query parameter

    A source that is present but empty (for example ``Authorization: Bearer ``
    with nothing after the prefix) is skipped so that a later source can still
    supply the token, instead of short-circuiting with an empty string.
    """
    auth_header = request.headers.get('Authorization', '')
    if auth_header.startswith('Bearer '):
        bearer_token = auth_header[len('Bearer '):].strip()
        if bearer_token:
            return bearer_token

    api_key = request.headers.get('x-api-key')
    if api_key:
        return api_key

    key_param = request.args.get('key')
    if key_param:
        return key_param

    return None

def get_client_ip() -> str:
    """Return the best-guess client IP address for the current request.

    Precedence: the first entry of ``X-Forwarded-For``, then ``X-Real-IP``,
    then the socket peer (``request.remote_addr``), and finally ``'127.0.0.1'``.

    NOTE(review): both headers are client-supplied unless a trusted reverse
    proxy overwrites them, so anything keyed on this value (e.g. the per-IP rate
    limit) can be bypassed if the app is reachable directly — confirm deployment.
    """
    forwarded_chain = request.headers.get('X-Forwarded-For')
    if forwarded_chain:
        # Left-most hop is the originating client by proxy convention.
        original_client, _sep, _rest = forwarded_chain.partition(',')
        return original_client.strip()

    nginx_ip = request.headers.get('X-Real-IP')
    if nginx_ip:
        return nginx_ip

    peer = request.remote_addr
    return peer if peer else '127.0.0.1'

# Length of the sliding rate-limit window, also used as the interval between
# sweeps that evict identifiers with no requests inside the window. Without
# the sweep, every identifier ever seen would stay in rate_limit_store forever.
_RATE_LIMIT_WINDOW_SECONDS = 60
_last_rate_limit_sweep = 0.0


def _sweep_stale_rate_limits(cutoff: float) -> None:
    """Remove identifiers whose newest recorded request is not after *cutoff*."""
    # Snapshot the items so concurrent inserts cannot break the iteration.
    for key, stamps in list(rate_limit_store.items()):
        # Timestamps are appended in time order, so stamps[-1] is the newest.
        if not stamps or stamps[-1] <= cutoff:
            rate_limit_store.pop(key, None)


def check_rate_limit(identifier: str, limit_per_minute: Optional[int] = None) -> Tuple[bool, int]:
    """Check and record a request against a sliding one-minute window.

    Args:
        identifier: Key the limit is tracked under (client IP or user token).
        limit_per_minute: Maximum requests allowed in the window; defaults to
            ``auth_config.rate_limit_per_minute`` when None.

    Returns:
        ``(allowed, remaining)``. When allowed, the request is recorded and
        ``remaining`` is how many more requests fit in the current window.
        When rejected, ``(False, 0)`` and nothing is recorded. When rate
        limiting is disabled, ``(True, limit_per_minute)``.
    """
    global _last_rate_limit_sweep

    if limit_per_minute is None:
        limit_per_minute = auth_config.rate_limit_per_minute

    if not auth_config.rate_limit_enabled:
        return True, limit_per_minute

    now = time.time()
    cutoff = now - _RATE_LIMIT_WINDOW_SECONDS

    # Opportunistic housekeeping, at most once per window.
    if now - _last_rate_limit_sweep >= _RATE_LIMIT_WINDOW_SECONDS:
        _last_rate_limit_sweep = now
        _sweep_stale_rate_limits(cutoff)

    # .get() avoids defaultdict creating an entry just to read it.
    window = [stamp for stamp in rate_limit_store.get(identifier, ()) if stamp > cutoff]

    if len(window) >= limit_per_minute:
        rate_limit_store[identifier] = window
        return False, 0

    window.append(now)
    rate_limit_store[identifier] = window
    return True, limit_per_minute - len(window)

def _authenticate_proxy_key(client_ip: str) -> Tuple[bool, Optional[str], Optional[dict]]:
    """Validate the shared proxy password, then apply the per-IP rate limit."""
    token = extract_token_from_request()
    if not token:
        return False, "Authentication required: provide API key", None

    expected = auth_config.proxy_password
    # compare_digest keeps the check constant-time so response timing does not
    # reveal how much of a guessed password matched. Bytes are compared because
    # compare_digest raises TypeError for non-ASCII str arguments.
    if not expected or not secrets.compare_digest(token.encode(), str(expected).encode()):
        return False, "Invalid API key", None

    allowed, remaining = check_rate_limit(client_ip)
    if not allowed:
        return False, "Rate limit exceeded", None

    return True, None, {
        "type": "proxy_key",
        "ip": client_ip,
        "rate_limit_remaining": remaining
    }


def _authenticate_user_token(client_ip: str) -> Tuple[bool, Optional[str], Optional[dict]]:
    """Validate a per-user token via user_store, then apply the per-token rate limit."""
    token = extract_token_from_request()
    if not token:
        return False, "Authentication required: provide user token", None

    auth_result, user = user_store.authenticate(token, client_ip)

    if auth_result == AuthResult.NOT_FOUND:
        return False, "Invalid user token", None

    if auth_result == AuthResult.DISABLED:
        reason = user.disabled_reason or "Account disabled"
        return False, f"Account disabled: {reason}", None

    if auth_result == AuthResult.LIMITED:
        return False, "IP address limit exceeded", None

    if auth_result != AuthResult.SUCCESS:
        # Fail closed on any unexpected result without blaming the auth mode.
        return False, "Authentication failed", None

    allowed, remaining = check_rate_limit(token)
    if not allowed:
        return False, "Rate limit exceeded", None

    return True, None, {
        "type": "user_token",
        "token": token,
        "user": user,
        "ip": client_ip,
        "rate_limit_remaining": remaining
    }


def authenticate_request() -> Tuple[bool, Optional[str], Optional[dict]]:
    """
    Authenticate the current request according to ``auth_config.mode``.

    Modes:
        ``"none"``       -- always succeeds.
        ``"proxy_key"``  -- token must equal the shared proxy password;
                            rate limited per client IP.
        ``"user_token"`` -- token is checked by ``user_store``;
                            rate limited per token.

    Returns (success, error_message, user_data). On failure user_data is None;
    on success error_message is None.
    """
    client_ip = get_client_ip()
    mode = auth_config.mode

    if mode == "none":
        return True, None, {"type": "none", "ip": client_ip}
    if mode == "proxy_key":
        return _authenticate_proxy_key(client_ip)
    if mode == "user_token":
        return _authenticate_user_token(client_ip)

    return False, "Invalid authentication mode", None

def require_auth(f):
    """Decorator that authenticates the request before running the view.

    On success, the payload returned by ``authenticate_request`` is stored on
    ``g.auth_data`` for the view and helpers such as ``check_quota`` /
    ``track_token_usage``. On failure, a JSON ``{"error": ...}`` body is
    returned with status 401. A rate-limit rejection returns 429 instead.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        success, error_message, user_data = authenticate_request()

        if not success:
            # Hitting the rate limit is not a credential failure. A 429 lets
            # clients back off and retry instead of treating their key as bad.
            status = 429 if error_message == "Rate limit exceeded" else 401
            return jsonify({"error": error_message}), status

        # Store auth data in Flask g object for use in endpoint
        g.auth_data = user_data

        return f(*args, **kwargs)

    return decorated_function

def require_admin_auth(f):
    """Decorator requiring ``Authorization: Bearer <admin key>`` on the request.

    Responds with 401 and a JSON error if the header is missing or malformed,
    or if the key does not match ``auth_config.admin_key``. An unset or empty
    configured admin key never authenticates anyone.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({"error": "Admin authentication required"}), 401

        admin_key = auth_header[len('Bearer '):]
        expected = auth_config.admin_key

        # Reject empty keys on either side: otherwise an empty configured key
        # would match a bare "Bearer " header. compare_digest keeps the check
        # constant-time; bytes avoid its TypeError on non-ASCII str input.
        if (
            not expected
            or not admin_key
            or not secrets.compare_digest(admin_key.encode(), str(expected).encode())
        ):
            return jsonify({"error": "Invalid admin key"}), 401

        return f(*args, **kwargs)

    return decorated_function

def check_quota(model_family: str) -> Tuple[bool, Optional[str]]:
    """Report whether the authenticated user may still spend tokens on *model_family*.

    Only requests authenticated in ``user_token`` mode are subject to quotas.
    Every other request (or one with no ``g.auth_data``) is allowed.

    Returns:
        ``(True, None)`` when allowed, otherwise ``(False, message)`` where the
        message includes the used/limit figures reported by ``user_store``.
    """
    is_user_token = hasattr(g, 'auth_data') and g.auth_data["type"] == "user_token"
    if not is_user_token:
        return True, None

    has_quota, used, limit = user_store.check_quota(g.auth_data["token"], model_family)
    if has_quota:
        return True, None

    return False, f"Quota exceeded for {model_family}: {used}/{limit} tokens used"

def track_token_usage(model_family: str, input_tokens: int, output_tokens: int, cost: float = 0.0, response_time_ms: float = 0.0, model_name: Optional[str] = None):
    """Record token usage for the current user-token-authenticated request.

    No-op unless ``g.auth_data`` was set by ``require_auth`` in ``user_token``
    mode.

    Args:
        model_family: Model family the usage is attributed to.
        input_tokens: Prompt tokens consumed.
        output_tokens: Completion tokens produced.
        cost: Cost forwarded to both the user record and the event logger.
        response_time_ms: Latency forwarded to the structured event logger.
        model_name: Concrete model name for the event log. When omitted, the
            placeholder ``f"{model_family}-model"`` is used, as before.
    """
    if not hasattr(g, 'auth_data') or g.auth_data["type"] != "user_token":
        return  # No tracking for non-user-token auth

    auth_data = g.auth_data
    token = auth_data["token"]
    user = auth_data["user"]
    # NOTE(review): an unsalted SHA-256 of an IP is easy to reverse by brute
    # force over the IPv4 space; changing it would break existing hashes, so
    # it is left as-is here.
    ip_hash = hashlib.sha256(auth_data["ip"].encode()).hexdigest()
    user_agent = request.headers.get('User-Agent', '')

    user.add_request_tracking(
        model_family=model_family,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cost=cost,
        ip_hash=ip_hash,
        user_agent=user_agent
    )

    # Mark user for Firebase sync immediately after the in-memory update, so the
    # recorded usage is still synced if the event logger below raises.
    user_store.flush_queue.add(token)

    # Function-scope import, presumably to avoid an import cycle — confirm.
    from ..services.firebase_logger import structured_logger
    structured_logger.log_chat_completion(
        user_token=token,
        model_family=model_family,
        model_name=model_name or f"{model_family}-model",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cost_usd=cost,
        response_time_ms=response_time_ms,
        success=True,
        ip_hash=ip_hash,
        user_agent=user_agent
    )

def require_admin_session(f):
    """Decorator for admin web pages: redirect to ``admin.login`` unless the
    session carries a truthy ``admin_authenticated`` flag."""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if session.get('admin_authenticated'):
            return f(*args, **kwargs)
        return redirect(url_for('admin.login'))
    return decorated_function

def generate_csrf_token():
    """Return the session's CSRF token, creating and storing one on first use.

    New tokens come from ``secrets.token_urlsafe(32)``; an existing token is
    reused for the lifetime of the session.
    """
    if 'csrf_token' in session:
        return session['csrf_token']
    fresh_token = secrets.token_urlsafe(32)
    session['csrf_token'] = fresh_token
    return fresh_token

def validate_csrf_token(token) -> bool:
    """Return True iff *token* matches the CSRF token stored in the session.

    An empty or missing token, or a session without a token, never validates.
    The comparison is constant-time so response timing does not reveal how much
    of a guessed token was correct.
    """
    if not token:
        return False
    expected = session.get('csrf_token')
    if not expected:
        return False
    # Compare as UTF-8 bytes: compare_digest raises TypeError on str inputs
    # containing non-ASCII characters, which a client could submit.
    return secrets.compare_digest(str(token).encode(), str(expected).encode())

def csrf_protect(f):
    """Decorator enforcing the session CSRF token on POST requests.

    The token is read from the ``_csrf`` form field, falling back to the
    ``X-CSRFToken`` header. A mismatch yields a 403 JSON error.

    NOTE(review): only POST is checked; PUT/PATCH/DELETE routes wrapped by this
    decorator are not protected — confirm that is intended.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if request.method != 'POST':
            return f(*args, **kwargs)
        submitted = request.form.get('_csrf') or request.headers.get('X-CSRFToken')
        if validate_csrf_token(submitted):
            return f(*args, **kwargs)
        return jsonify({'error': 'CSRF token validation failed'}), 403
    return decorated_function