Spaces:
Paused
Paused
| from flask import request, jsonify, g, session, redirect, url_for | |
| from functools import wraps | |
| from typing import Optional, Tuple | |
| import time | |
| import hashlib | |
| from collections import defaultdict | |
| import secrets | |
| from ..config.auth import auth_config | |
| from ..services.user_store import user_store, AuthResult | |
# Rate limiting storage.
# Module-level, in-memory mapping used by check_rate_limit: each key is the
# identifier passed to that function and each value is the list of
# time.time() timestamps recorded for that identifier's recent requests.
rate_limit_store = defaultdict(list)
def extract_token_from_request() -> Optional[str]:
    """Return the credential supplied with the current request, if any.

    Sources are consulted in priority order:

    1. ``Authorization`` header that starts with ``Bearer `` -- the text
       after that prefix is returned as-is (even when it is empty).
    2. ``x-api-key`` header, when non-empty.
    3. ``key`` query-string parameter, when non-empty.

    Returns ``None`` when no source applies.
    """
    bearer_prefix = 'Bearer '
    authorization = request.headers.get('Authorization')
    if authorization and authorization.startswith(bearer_prefix):
        # Everything after the scheme prefix is the token itself.
        return authorization[len(bearer_prefix):]

    # Fall back to the API-key header, then the query string; an empty value
    # in either place counts as absent.
    return request.headers.get('x-api-key') or request.args.get('key') or None
def get_client_ip() -> str:
    """Return the best-guess client IP address for the current request.

    Lookup order:

    1. First non-blank entry of the ``X-Forwarded-For`` header.
    2. ``X-Real-IP`` header (whitespace-stripped), when non-blank.
    3. ``request.remote_addr``, or ``'127.0.0.1'`` when that is unset.

    Blank header values are skipped rather than returned, so the result is
    never an empty string.
    """
    # NOTE(security): both forwarding headers are supplied by the client
    # unless a trusted reverse proxy overwrites them. The value returned here
    # is used as a rate-limit key by authenticate_request, so a deployment
    # that is reachable without such a proxy lets callers pick their own key.
    forwarded_for = request.headers.get('X-Forwarded-For') or ''
    for hop in forwarded_for.split(','):
        hop = hop.strip()
        if hop:
            # The left-most non-blank entry is the originating client.
            return hop

    real_ip = (request.headers.get('X-Real-IP') or '').strip()
    if real_ip:
        return real_ip

    # Fallback to the socket peer address.
    return request.remote_addr or '127.0.0.1'
# Length, in seconds, of the sliding window check_rate_limit counts over.
_RATE_LIMIT_WINDOW_SECONDS = 60

# time.time() of the most recent whole-store sweep (0.0 means "never").
_rate_limit_last_sweep = 0.0


def _sweep_rate_limit_store(cutoff: float) -> None:
    """Remove every bucket that holds no timestamp newer than ``cutoff``."""
    # Iterate over a snapshot so the dict is never resized mid-iteration.
    for key, stamps in list(rate_limit_store.items()):
        if not any(stamp > cutoff for stamp in stamps):
            rate_limit_store.pop(key, None)


def check_rate_limit(identifier: str, limit_per_minute: Optional[int] = None) -> Tuple[bool, int]:
    """Check whether ``identifier`` may make another request right now.

    Uses a sliding 60-second window over the timestamps kept in
    ``rate_limit_store``. An allowed call is recorded; a rejected call is not.

    Args:
        identifier: Key to count requests under (callers in this module pass
            a client IP or a user token).
        limit_per_minute: Maximum requests per window. ``None`` means use
            ``auth_config.rate_limit_per_minute``.

    Returns:
        ``(allowed, remaining)``. ``remaining`` is the number of further
        requests still permitted in the current window (0 when rejected).
        When rate limiting is disabled the call is always allowed and
        ``remaining`` is simply the configured limit.
    """
    global _rate_limit_last_sweep

    if not auth_config.rate_limit_enabled:
        return True, limit_per_minute or auth_config.rate_limit_per_minute

    if limit_per_minute is None:
        limit_per_minute = auth_config.rate_limit_per_minute

    current_time = time.time()
    window_start = current_time - _RATE_LIMIT_WINDOW_SECONDS

    # Identifiers that stop sending requests would otherwise keep their bucket
    # forever, so the store would grow without bound. Sweep expired buckets,
    # but at most once per window to keep the per-request cost low.
    if current_time - _rate_limit_last_sweep >= _RATE_LIMIT_WINDOW_SECONDS:
        _rate_limit_last_sweep = current_time
        _sweep_rate_limit_store(window_start)

    # Keep only the timestamps that are still inside the window.
    recent = [
        stamp for stamp in rate_limit_store.get(identifier, ())
        if stamp > window_start
    ]

    if len(recent) >= limit_per_minute:
        if recent:
            rate_limit_store[identifier] = recent
        else:
            # Nothing worth remembering (only possible for a limit <= 0).
            rate_limit_store.pop(identifier, None)
        return False, 0

    # Record the current request.
    recent.append(current_time)
    rate_limit_store[identifier] = recent
    return True, limit_per_minute - len(recent)
def authenticate_request() -> Tuple[bool, Optional[str], Optional[dict]]:
    """Authenticate the current request according to ``auth_config.mode``.

    Modes:
        ``"none"``: every request is accepted.
        ``"proxy_key"``: the request token must equal
            ``auth_config.proxy_password``; requests are rate limited per
            client IP.
        ``"user_token"``: the token is checked by ``user_store.authenticate``;
            requests are rate limited per token.

    Returns:
        ``(success, error_message, user_data)``. On success ``error_message``
        is ``None`` and ``user_data`` is a dict describing the caller
        (``type``, ``ip`` and, depending on the mode, ``token``, ``user`` and
        ``rate_limit_remaining``). On failure ``user_data`` is ``None``.
    """
    client_ip = get_client_ip()
    mode = auth_config.mode

    if mode == "none":
        # No authentication required
        return True, None, {"type": "none", "ip": client_ip}

    if mode == "proxy_key":
        # Single proxy password authentication
        token = extract_token_from_request()
        if not token:
            return False, "Authentication required: provide API key", None

        expected = auth_config.proxy_password
        # Compare in constant time so response timing does not reveal how many
        # leading characters of a guess were correct. Bytes are used because
        # compare_digest rejects non-ASCII str arguments. An unset password
        # can never match.
        if not expected or not secrets.compare_digest(
            token.encode("utf-8"), expected.encode("utf-8")
        ):
            return False, "Invalid API key", None

        # Check rate limit by IP for proxy key mode
        allowed, remaining = check_rate_limit(client_ip)
        if not allowed:
            return False, "Rate limit exceeded", None

        return True, None, {
            "type": "proxy_key",
            "ip": client_ip,
            "rate_limit_remaining": remaining,
        }

    if mode == "user_token":
        # Individual user token authentication
        token = extract_token_from_request()
        if not token:
            return False, "Authentication required: provide user token", None

        # Authenticate with user store
        auth_result, user = user_store.authenticate(token, client_ip)

        if auth_result == AuthResult.NOT_FOUND:
            return False, "Invalid user token", None
        if auth_result == AuthResult.DISABLED:
            # getattr: do not crash if the store returned no user object.
            reason = getattr(user, "disabled_reason", None) or "Account disabled"
            return False, f"Account disabled: {reason}", None
        if auth_result == AuthResult.LIMITED:
            return False, "IP address limit exceeded", None
        if auth_result != AuthResult.SUCCESS:
            # Unrecognised outcome: fail closed, and do not blame the
            # configured mode for it.
            return False, "Authentication failed", None

        # Check rate limit by user token
        allowed, remaining = check_rate_limit(token)
        if not allowed:
            return False, "Rate limit exceeded", None

        return True, None, {
            "type": "user_token",
            "token": token,
            "user": user,
            "ip": client_ip,
            "rate_limit_remaining": remaining,
        }

    return False, "Invalid authentication mode", None
def require_auth(f):
    """Decorator that requires a successfully authenticated request.

    Runs ``authenticate_request()`` before the wrapped view. On failure the
    view is not called and a JSON body ``{"error": <message>}`` is returned
    with status 401. On success the auth data is stored on ``g.auth_data``
    and the view's own return value is passed through.
    """
    # wraps() keeps the view's __name__; Flask derives the default endpoint
    # name from it, so without this every decorated view would be registered
    # as "decorated_function" and a second decorated route would collide.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        success, error_message, user_data = authenticate_request()
        if not success:
            return jsonify({"error": error_message}), 401

        # Store auth data in Flask g object for use in endpoint
        g.auth_data = user_data
        return f(*args, **kwargs)

    return decorated_function
def require_admin_auth(f):
    """Decorator that requires the admin key as a Bearer token.

    The request must carry ``Authorization: Bearer <key>`` where ``<key>``
    equals ``auth_config.admin_key``. Otherwise the wrapped view is not
    called and a JSON error with status 401 is returned.
    """
    # wraps() keeps the view's __name__ so Flask endpoint names stay unique.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Check for admin key in Authorization header
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({"error": "Admin authentication required"}), 401

        admin_key = auth_header[7:]  # Remove "Bearer " prefix
        expected = auth_config.admin_key
        # An unset/empty configured key must never authenticate anyone, and
        # the comparison is constant-time so response timing leaks nothing
        # about the key. Bytes are used because compare_digest rejects
        # non-ASCII str arguments.
        if not expected or not secrets.compare_digest(
            admin_key.encode("utf-8"), expected.encode("utf-8")
        ):
            return jsonify({"error": "Invalid admin key"}), 401

        return f(*args, **kwargs)

    return decorated_function
def check_quota(model_family: str) -> Tuple[bool, Optional[str]]:
    """Report whether the current caller may still use ``model_family``.

    Quotas apply only when ``g.auth_data`` exists and its ``"type"`` is
    ``"user_token"``; any other request is always allowed.

    Returns:
        ``(True, None)`` when allowed, otherwise ``(False, message)`` where
        ``message`` states the used and limit figures reported by
        ``user_store.check_quota``.
    """
    is_user_token_request = hasattr(g, 'auth_data') and g.auth_data["type"] == "user_token"
    if not is_user_token_request:
        # No quota limits for non-user-token auth
        return True, None

    has_quota, used, limit = user_store.check_quota(g.auth_data["token"], model_family)
    if has_quota:
        return True, None

    return False, f"Quota exceeded for {model_family}: {used}/{limit} tokens used"
def track_token_usage(model_family: str, input_tokens: int, output_tokens: int, cost: float = 0.0, response_time_ms: float = 0.0):
    """Record token usage for the user behind the current request.

    Does nothing unless ``g.auth_data`` exists and its ``"type"`` is
    ``"user_token"``. Otherwise, in order, it:

    1. calls ``user.add_request_tracking(...)`` on the authenticated user,
    2. logs a chat-completion event through ``structured_logger``,
    3. adds the token to ``user_store.flush_queue``.

    The client IP is passed on only as its SHA-256 hex digest.
    """
    tracks_usage = hasattr(g, 'auth_data') and g.auth_data["type"] == "user_token"
    if not tracks_usage:
        # No tracking for non-user-token auth
        return

    auth_data = g.auth_data
    user_token = auth_data["token"]
    account = auth_data["user"]

    # Client details shared by both tracking calls below.
    client = {
        "ip_hash": hashlib.sha256(auth_data["ip"].encode()).hexdigest(),
        "user_agent": request.headers.get('User-Agent', ''),
    }

    # Per-user enhanced tracking.
    account.add_request_tracking(
        model_family=model_family,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cost=cost,
        **client,
    )

    # Structured event logger; imported at call time rather than at module
    # load (presumably to avoid an import cycle -- confirm).
    from ..services.firebase_logger import structured_logger

    structured_logger.log_chat_completion(
        user_token=user_token,
        model_family=model_family,
        model_name=f"{model_family}-model",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cost_usd=cost,
        response_time_ms=response_time_ms,
        success=True,
        **client,
    )

    # Mark user for Firebase sync
    user_store.flush_queue.add(user_token)
def require_admin_session(f):
    """Decorator that requires an authenticated admin session (web interface).

    When ``session['admin_authenticated']`` is missing or falsy the wrapped
    view is not called and the client is redirected to the ``admin.login``
    endpoint; otherwise the view's return value is passed through.
    """
    # wraps() keeps the view's __name__; Flask derives the default endpoint
    # name from it, so without this every decorated view would be registered
    # as "decorated_function" and a second decorated route would collide.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not session.get('admin_authenticated'):
            return redirect(url_for('admin.login'))
        return f(*args, **kwargs)

    return decorated_function
def generate_csrf_token():
    """Return the session's CSRF token, creating one on first use.

    The token is stored under ``session['csrf_token']``; an existing value is
    returned unchanged, otherwise a new ``secrets.token_urlsafe(32)`` value is
    stored and returned.
    """
    if 'csrf_token' in session:
        return session['csrf_token']

    session['csrf_token'] = secrets.token_urlsafe(32)
    return session['csrf_token']
def validate_csrf_token(token):
    """Return ``True`` only if ``token`` matches the session's CSRF token.

    Args:
        token: Token submitted by the client (may be ``None``).

    Returns:
        A real ``bool``. ``False`` when ``token`` is missing or empty, when
        the session holds no CSRF token, or when the two differ.
    """
    if not token:
        return False

    expected = session.get('csrf_token')
    if not expected or not isinstance(token, str) or not isinstance(expected, str):
        return False

    # Constant-time comparison so response timing reveals nothing about the
    # stored token. Bytes are used because compare_digest rejects non-ASCII
    # str arguments.
    return secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
def csrf_protect(f):
    """Decorator to protect against CSRF attacks.

    For state-changing requests (``POST``, ``PUT``, ``PATCH``, ``DELETE``)
    the CSRF token is read from the ``_csrf`` form field or, failing that,
    the ``X-CSRFToken`` header, and checked with ``validate_csrf_token``.
    A missing or wrong token yields a JSON error with status 403 and the
    wrapped view is not called. Other methods pass straight through.
    """
    # wraps() keeps the view's __name__ so Flask endpoint names stay unique.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Every method that can change state is checked, not just POST;
        # otherwise a PUT/PATCH/DELETE route would be left unprotected.
        if request.method in ('POST', 'PUT', 'PATCH', 'DELETE'):
            token = request.form.get('_csrf') or request.headers.get('X-CSRFToken')
            if not validate_csrf_token(token):
                return jsonify({'error': 'CSRF token validation failed'}), 403
        return f(*args, **kwargs)

    return decorated_function