"""
Authentication Module for RAG API.
Implements JWT-based authentication with rate limiting.
"""

import os
import time
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict
from dataclasses import dataclass
from functools import wraps

from ..utils import get_logger

logger = get_logger(__name__)

# Built-in fallback admin API key, kept only for backward compatibility.
# It is publicly known (it lives in source control), so deployments must
# override it via the ADMIN_API_KEY environment variable.
_DEFAULT_ADMIN_API_KEY = "rag-admin-key-12345"


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same kind of value as the deprecated ``datetime.utcnow()``
    (naive, expressed in UTC), so stored ``created_at`` values and token
    timestamps keep the representation this module has always used.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class User:
    """User model."""

    user_id: str
    username: str
    email: str
    api_key: str
    created_at: datetime
    is_active: bool = True
    role: str = "user"  # user, admin


class JWTAuth:
    """
    JWT-based authentication handler.

    Keeps users and API keys in process memory and issues/verifies JWTs
    signed with a shared secret.
    """

    def __init__(
        self,
        secret_key: Optional[str] = None,
        algorithm: str = "HS256",
        access_token_expire_minutes: int = 30
    ):
        """
        Initialize JWT authentication.

        Args:
            secret_key: Secret key for JWT signing. When omitted or empty,
                the JWT_SECRET_KEY environment variable is used; when that
                is unset or empty too, a random per-process secret is
                generated.
            algorithm: JWT algorithm
            access_token_expire_minutes: Token expiration time
        """
        # Treat an empty argument / empty environment variable as "not set":
        # signing tokens with an empty key would make them trivially forgeable.
        resolved_secret = secret_key or os.getenv("JWT_SECRET_KEY")
        if not resolved_secret:
            resolved_secret = secrets.token_hex(32)
            logger.warning(
                "JWT_SECRET_KEY not set; using a random per-process secret "
                "(tokens will not validate across restarts or other processes)"
            )
        self.secret_key = resolved_secret
        self.algorithm = algorithm
        self.access_token_expire_minutes = access_token_expire_minutes

        # In-memory user store (replace with database in production)
        self.users: Dict[str, User] = {}
        self.api_keys: Dict[str, str] = {}  # api_key -> user_id

        # Create default admin user
        self._create_default_admin()

    def _create_default_admin(self):
        """Create a default admin user.

        The admin API key comes from the ADMIN_API_KEY environment variable.
        An unset or empty variable falls back to the built-in default key and
        logs a warning, because that default is publicly known.
        """
        # `or` (rather than a getenv default) so that ADMIN_API_KEY="" cannot
        # register the empty string as a valid admin credential.
        admin_key = os.getenv("ADMIN_API_KEY") or _DEFAULT_ADMIN_API_KEY
        if admin_key == _DEFAULT_ADMIN_API_KEY:
            logger.warning(
                "ADMIN_API_KEY not set; using the insecure built-in default "
                "admin API key"
            )

        admin = User(
            user_id="admin",
            username="admin",
            email="admin@localhost",
            api_key=admin_key,
            created_at=_utcnow(),
            role="admin"
        )
        self.users["admin"] = admin
        self.api_keys[admin_key] = "admin"
        logger.info("Default admin user created")

    def create_access_token(self, user_id: str, expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a JWT access token.

        Args:
            user_id: User identifier
            expires_delta: Token expiration time; defaults to
                ``access_token_expire_minutes`` when omitted.

        Returns:
            JWT token string

        Raises:
            ImportError: If PyJWT is not installed.
        """
        try:
            import jwt
        except ImportError:
            logger.error("PyJWT not installed. Install with: pip install PyJWT")
            raise

        # Read the clock once so "iat" and "exp" are derived from the same
        # instant.
        issued_at = _utcnow()
        lifetime = expires_delta or timedelta(minutes=self.access_token_expire_minutes)
        payload = {
            "sub": user_id,
            "exp": issued_at + lifetime,
            "iat": issued_at
        }

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
        return token

    def verify_token(self, token: str) -> Optional[str]:
        """
        Verify a JWT token and return user_id.

        Args:
            token: JWT token string

        Returns:
            User ID if valid, None otherwise (including when PyJWT is not
            installed, which is logged as an error).
        """
        if not token:
            return None

        try:
            import jwt
        except ImportError:
            # A missing dependency is a deployment problem, not an invalid
            # token: make it visible instead of hiding it at debug level.
            logger.error("PyJWT not installed. Install with: pip install PyJWT")
            return None

        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.PyJWTError as e:
            logger.debug(f"Token verification failed: {e}")
            return None

        # NOTE(review): the subject is returned without checking that the
        # user still exists or is active - confirm callers perform that check.
        return payload.get("sub")

    def verify_api_key(self, api_key: str) -> Optional[User]:
        """
        Verify an API key and return the user.

        Args:
            api_key: API key string

        Returns:
            User if valid, None otherwise
        """
        # An empty key is never a valid credential.
        if not api_key:
            return None

        user_id = self.api_keys.get(api_key)
        if user_id is None:
            return None

        # NOTE(review): `User.is_active` is not consulted here - confirm
        # whether deactivated users should be rejected at this layer.
        return self.users.get(user_id)

    def create_user(self, username: str, email: str, role: str = "user") -> User:
        """
        Create a new user with API key.

        Args:
            username: Username
            email: Email address
            role: User role

        Returns:
            Created user
        """
        user_id = secrets.token_hex(8)
        api_key = secrets.token_urlsafe(32)

        user = User(
            user_id=user_id,
            username=username,
            email=email,
            api_key=api_key,
            created_at=_utcnow(),
            role=role
        )

        self.users[user_id] = user
        self.api_keys[api_key] = user_id

        logger.info(f"Created user: {username}")
        return user


class RateLimiter:
    """
    Simple in-memory rate limiter.
    Uses sliding window algorithm.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000
    ):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Max requests per minute
            requests_per_hour: Max requests per hour
        """
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour

        # Track requests: user_id -> list of timestamps
        self.requests: Dict[str, list] = {}

    def is_allowed(self, user_id: str) -> bool:
        """
        Check if a request is allowed for the user.

        An allowed request is recorded against the user's quota; a rejected
        one is not.

        Args:
            user_id: User identifier

        Returns:
            True if allowed, False if rate limited
        """
        now = time.time()
        minute_ago = now - 60
        hour_ago = now - 3600

        # Clean old requests: only the last hour is ever consulted.
        history = [ts for ts in self.requests.get(user_id, []) if ts > hour_ago]
        self.requests[user_id] = history

        # Check limits
        recent_minute = sum(1 for ts in history if ts > minute_ago)
        recent_hour = len(history)

        if recent_minute >= self.requests_per_minute:
            logger.warning(f"Rate limit exceeded (minute) for {user_id}")
            return False

        if recent_hour >= self.requests_per_hour:
            logger.warning(f"Rate limit exceeded (hour) for {user_id}")
            return False

        # Record request
        history.append(now)
        return True

    def get_remaining(self, user_id: str) -> Dict[str, int]:
        """
        Get remaining requests for a user.

        Does not record a request or modify the stored history.

        Args:
            user_id: User identifier

        Returns:
            Dict with remaining requests, under the keys
            "minute_remaining" and "hour_remaining"
        """
        now = time.time()
        minute_ago = now - 60
        hour_ago = now - 3600

        requests = self.requests.get(user_id, [])
        recent_minute = sum(1 for ts in requests if ts > minute_ago)
        recent_hour = sum(1 for ts in requests if ts > hour_ago)

        return {
            "minute_remaining": max(0, self.requests_per_minute - recent_minute),
            "hour_remaining": max(0, self.requests_per_hour - recent_hour)
        }


# Global instances
_auth: Optional[JWTAuth] = None
_rate_limiter: Optional[RateLimiter] = None


def get_auth() -> JWTAuth:
    """Get global auth instance, creating it on first use."""
    global _auth
    if _auth is None:
        _auth = JWTAuth()
    return _auth


def get_rate_limiter() -> RateLimiter:
    """Get global rate limiter, creating it on first use."""
    global _rate_limiter
    if _rate_limiter is None:
        _rate_limiter = RateLimiter()
    return _rate_limiter