WasabiDrop's picture
🧹 Clean up local database files and emphasize Firebase requirement
05fe460
from flask import request, jsonify, g, session, redirect, url_for
from functools import wraps
from typing import Optional, Tuple
import time
import hashlib
from collections import defaultdict
import secrets
from ..config.auth import auth_config
from ..services.user_store import user_store, AuthResult
# Rate limiting storage: maps an identifier (the client IP or the user token,
# whichever the caller passes to check_rate_limit) to the list of time.time()
# stamps of that identifier's recent requests. This is a plain module-level
# dict, so it lives only in this process's memory.
rate_limit_store = defaultdict(list)
def extract_token_from_request() -> Optional[str]:
    """Pull the caller's credential out of the current Flask request.

    Sources are consulted in this order:
      1. ``Authorization: Bearer <token>`` header - returned as-is once the
         prefix matches, even if the part after the prefix is empty.
      2. ``x-api-key`` header, when non-empty.
      3. ``key`` query-string parameter, when non-empty.

    Returns None when no source yields a value.
    """
    bearer_prefix = 'Bearer '
    authorization = request.headers.get('Authorization')
    if authorization and authorization.startswith(bearer_prefix):
        return authorization[len(bearer_prefix):]

    # Header first, then query string; empty strings count as "absent".
    return request.headers.get('x-api-key') or request.args.get('key') or None
def get_client_ip() -> str:
    """Best-effort client address for the current request.

    Preference order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``,
    then ``request.remote_addr``, and finally the literal ``'127.0.0.1'``.

    NOTE(review): both headers are taken at face value here; if the app is
    reachable without a trusted proxy in front, callers can set them freely -
    confirm the deployment strips or overwrites them.
    """
    headers = request.headers

    forwarded_chain = headers.get('X-Forwarded-For')
    if forwarded_chain:
        # The header is a comma-separated list; only the first entry is used.
        first_hop, _, _ = forwarded_chain.partition(',')
        return first_hop.strip()

    return headers.get('X-Real-IP') or request.remote_addr or '127.0.0.1'
def check_rate_limit(identifier: str, limit_per_minute: int = None) -> Tuple[bool, int]:
    """Sliding 60-second window limiter keyed by ``identifier``.

    Args:
        identifier: Key into ``rate_limit_store`` (callers in this file pass a
            client IP or a user token).
        limit_per_minute: Maximum requests per window; ``None`` means use
            ``auth_config.rate_limit_per_minute``.

    Returns:
        ``(allowed, remaining)``. When the limiter is disabled the request is
        always allowed and the configured limit is reported as remaining.
        A rejected request is not recorded and reports 0 remaining.
    """
    if not auth_config.rate_limit_enabled:
        return True, limit_per_minute or auth_config.rate_limit_per_minute

    limit = (
        auth_config.rate_limit_per_minute
        if limit_per_minute is None
        else limit_per_minute
    )

    now = time.time()
    window_start = now - 60

    # Drop stamps that have aged out of the window and store the pruned list.
    recent = [stamp for stamp in rate_limit_store[identifier] if stamp > window_start]
    rate_limit_store[identifier] = recent

    used = len(recent)
    if used >= limit:
        return False, 0

    # Record this request only when it is admitted.
    recent.append(now)
    return True, limit - used - 1
def authenticate_request() -> Tuple[bool, Optional[str], Optional[dict]]:
    """
    Authenticate the current request according to ``auth_config.mode``.

    Modes:
      * ``"none"``       - every request is accepted.
      * ``"proxy_key"``  - the supplied token must equal
        ``auth_config.proxy_password``; rate limiting is keyed by client IP.
      * ``"user_token"`` - the token is checked by ``user_store.authenticate``;
        rate limiting is keyed by the token itself.

    Returns:
        ``(success, error_message, user_data)``. On success ``error_message``
        is None and ``user_data`` is a dict whose ``"type"`` is the mode name.
        On failure ``user_data`` is None and ``error_message`` says why.
    """
    client_ip = get_client_ip()
    mode = auth_config.mode

    if mode == "none":
        # No authentication required
        return True, None, {"type": "none", "ip": client_ip}

    if mode == "proxy_key":
        token = extract_token_from_request()
        if not token:
            return False, "Authentication required: provide API key", None

        # Compare in constant time so response latency does not reveal how
        # many leading characters of the password were guessed correctly.
        # Bytes are used because compare_digest rejects non-ASCII str input.
        # A non-string (e.g. unset) password never matches.
        expected = auth_config.proxy_password
        password_ok = isinstance(expected, str) and secrets.compare_digest(
            token.encode("utf-8"), expected.encode("utf-8")
        )
        if not password_ok:
            return False, "Invalid API key", None

        # Everyone shares one password in this mode, so throttle per IP.
        allowed, remaining = check_rate_limit(client_ip)
        if not allowed:
            return False, "Rate limit exceeded", None

        return True, None, {
            "type": "proxy_key",
            "ip": client_ip,
            "rate_limit_remaining": remaining
        }

    if mode == "user_token":
        token = extract_token_from_request()
        if not token:
            return False, "Authentication required: provide user token", None

        auth_result, user = user_store.authenticate(token, client_ip)

        if auth_result == AuthResult.NOT_FOUND:
            return False, "Invalid user token", None

        if auth_result == AuthResult.DISABLED:
            # getattr guards against the store returning no user object here.
            reason = getattr(user, "disabled_reason", None) or "Account disabled"
            return False, f"Account disabled: {reason}", None

        if auth_result == AuthResult.LIMITED:
            return False, "IP address limit exceeded", None

        if auth_result != AuthResult.SUCCESS:
            # An unrecognised result is an auth failure, not a bad mode;
            # report it as such instead of falling through to the mode error.
            return False, "Authentication failed", None

        # Throttle per user token.
        allowed, remaining = check_rate_limit(token)
        if not allowed:
            return False, "Rate limit exceeded", None

        return True, None, {
            "type": "user_token",
            "token": token,
            "user": user,
            "ip": client_ip,
            "rate_limit_remaining": remaining
        }

    return False, "Invalid authentication mode", None
def require_auth(f):
    """Decorator to require authentication for endpoints.

    Runs ``authenticate_request()`` before the wrapped view. On failure a JSON
    body ``{"error": <message>}`` is returned with status 429 when the request
    was rejected by the rate limiter and 401 otherwise. On success the auth
    data dict is stored on ``g.auth_data`` and the view is called.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        success, error_message, user_data = authenticate_request()

        if not success:
            # Throttling is not a credential problem: 429 lets clients back
            # off and retry instead of treating their key as invalid.
            # The literal must match the message authenticate_request emits.
            status = 429 if error_message == "Rate limit exceeded" else 401
            return jsonify({"error": error_message}), status

        # Store auth data in Flask g object for use in endpoint
        g.auth_data = user_data

        return f(*args, **kwargs)

    return decorated_function
def require_admin_auth(f):
    """Decorator to require admin authentication.

    The request must carry ``Authorization: Bearer <admin key>`` where the key
    equals ``auth_config.admin_key``. Any failure returns a JSON error body
    with status 401; otherwise the wrapped view is called.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({"error": "Admin authentication required"}), 401

        supplied_key = auth_header[7:]  # Remove "Bearer " prefix
        expected_key = auth_config.admin_key

        # An unset or empty configured key must never authenticate anyone
        # (otherwise an empty bearer token would match it). The comparison is
        # constant-time, on bytes because compare_digest rejects non-ASCII str.
        key_ok = (
            isinstance(expected_key, str)
            and bool(expected_key)
            and secrets.compare_digest(
                supplied_key.encode("utf-8"), expected_key.encode("utf-8")
            )
        )
        if not key_ok:
            return jsonify({"error": "Invalid admin key"}), 401

        return f(*args, **kwargs)

    return decorated_function
def check_quota(model_family: str) -> Tuple[bool, Optional[str]]:
    """Report whether the authenticated user may still use ``model_family``.

    Quotas apply only when ``g.auth_data`` exists and its ``"type"`` is
    ``"user_token"``; every other request is allowed unconditionally.

    Returns:
        ``(True, None)`` when allowed, otherwise ``(False, message)`` with the
        used/limit figures reported by ``user_store.check_quota``.
    """
    is_user_token = hasattr(g, 'auth_data') and g.auth_data["type"] == "user_token"
    if not is_user_token:
        return True, None

    within_quota, used, limit = user_store.check_quota(
        g.auth_data["token"], model_family
    )
    if within_quota:
        return True, None

    return False, f"Quota exceeded for {model_family}: {used}/{limit} tokens used"
def track_token_usage(model_family: str, input_tokens: int, output_tokens: int, cost: float = 0.0, response_time_ms: float = 0.0):
    """Record token usage for the authenticated user.

    Does nothing unless ``g.auth_data`` exists with ``"type" == "user_token"``.
    Otherwise, in order:
      1. calls ``add_request_tracking`` on the user object from ``g.auth_data``;
      2. logs a chat-completion event through ``structured_logger``;
      3. adds the token to ``user_store.flush_queue``.

    The client IP is passed on only as a SHA-256 hex digest.
    """
    if not (hasattr(g, 'auth_data') and g.auth_data["type"] == "user_token"):
        return  # No tracking for non-user-token auth

    auth_data = g.auth_data
    user_token = auth_data["token"]
    account = auth_data["user"]
    hashed_ip = hashlib.sha256(auth_data["ip"].encode()).hexdigest()
    agent = request.headers.get('User-Agent', '')

    # Per-user tracking on the user object itself.
    account.add_request_tracking(
        model_family=model_family,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cost=cost,
        ip_hash=hashed_ip,
        user_agent=agent,
    )

    # Structured event log; imported here, at call time, as in the original.
    from ..services.firebase_logger import structured_logger
    structured_logger.log_chat_completion(
        user_token=user_token,
        model_family=model_family,
        model_name=f"{model_family}-model",
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cost_usd=cost,
        response_time_ms=response_time_ms,
        success=True,
        ip_hash=hashed_ip,
        user_agent=agent,
    )

    # Queue this token on the store's flush queue (marks the user for sync).
    user_store.flush_queue.add(user_token)
def require_admin_session(f):
    """Decorator for web-interface views that need a logged-in admin session.

    If the session's ``admin_authenticated`` entry is truthy the wrapped view
    runs; otherwise the caller is redirected to the ``admin.login`` endpoint.
    """
    @wraps(f)
    def session_guarded(*args, **kwargs):
        if session.get('admin_authenticated'):
            return f(*args, **kwargs)
        return redirect(url_for('admin.login'))

    return session_guarded
def generate_csrf_token():
    """Return the session's CSRF token, creating one on first use.

    A new token comes from ``secrets.token_urlsafe(32)`` and is stored under
    the session key ``'csrf_token'``; later calls return the stored value.
    """
    if 'csrf_token' in session:
        return session['csrf_token']

    fresh_token = secrets.token_urlsafe(32)
    session['csrf_token'] = fresh_token
    return fresh_token
def validate_csrf_token(token):
    """Validate a submitted CSRF token against the one stored in the session.

    Args:
        token: The value submitted by the client (may be None or empty).

    Returns:
        True only when ``token`` is a non-empty string equal to the session's
        ``'csrf_token'``; False in every other case. The result is always a
        bool, and the comparison is constant-time.
    """
    # A missing/empty submission fails before the session is even consulted.
    if not token or not isinstance(token, str):
        return False

    expected = session.get('csrf_token')
    if not expected or not isinstance(expected, str):
        return False

    # Compare as bytes: compare_digest rejects non-ASCII str input.
    return secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
def csrf_protect(f):
    """Decorator that enforces a CSRF token on POST requests.

    For POST, the token is read from the ``_csrf`` form field or, failing
    that, the ``X-CSRFToken`` header, and checked with ``validate_csrf_token``;
    a failed check yields a JSON error with status 403. Requests using any
    other method reach the wrapped view unchecked.

    NOTE(review): PUT/PATCH/DELETE are not checked - confirm no
    state-changing route relies on this decorator with those methods.
    """
    @wraps(f)
    def csrf_checked(*args, **kwargs):
        if request.method != 'POST':
            return f(*args, **kwargs)

        submitted = request.form.get('_csrf') or request.headers.get('X-CSRFToken')
        if validate_csrf_token(submitted):
            return f(*args, **kwargs)

        return jsonify({'error': 'CSRF token validation failed'}), 403

    return csrf_checked