| """ |
| API Security Module |
| |
| This module provides security features for the API, including: |
| 1. Authentication using JWT tokens |
| 2. Rate limiting to prevent abuse |
| 3. Role-based access control |
| 4. Request validation |
| 5. Audit logging |
| """ |
import os
import time
import logging
import secrets
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Union, Any, Callable

from fastapi import Depends, HTTPException, Security, status, Request
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from pydantic import BaseModel, EmailStr

from src.models.user import User
from src.api.database import get_db
|
|
| |
logger = logging.getLogger(__name__)


# JWT signing configuration.
# A set-but-empty JWT_SECRET_KEY must not be used as the HMAC key (an empty
# key makes every token trivially forgeable), so any falsy value falls back
# to a random key generated for this process.
_configured_secret_key = os.getenv("JWT_SECRET_KEY")
if _configured_secret_key:
    SECRET_KEY = _configured_secret_key
else:
    logger.warning(
        "JWT_SECRET_KEY is not set (or empty); using a random per-process key. "
        "Issued tokens will not survive a restart or be accepted by other processes."
    )
    SECRET_KEY = secrets.token_hex(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
API_KEY_NAME = "X-API-Key"
|
|
| |
# passlib context used by verify_password / get_password_hash; bcrypt is the
# only configured scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
| |
# Extracts a Bearer token from the Authorization header (password flow whose
# token endpoint is the relative path "token").
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Reads the API-key header; auto_error=False lets a missing header arrive as
# None so callers can fall back to other credentials.
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME)
|
|
| |
class Token(BaseModel):
    """Access-token model: the token string, its type, and its expiry time."""

    access_token: str  # presumably the JWT produced by create_access_token — confirm at the call site
    token_type: str  # NOTE(review): not set anywhere in this module; likely "bearer"
    expires_at: datetime  # when access_token stops being valid
|
|
class TokenData(BaseModel):
    """Claims pulled from a decoded JWT: the ``sub`` username and its scope list."""

    username: Union[str, None] = None
    scopes: List[str] = []
|
|
class UserInDB(BaseModel):
    """
    User as seen by the auth layer, together with the scopes it was granted.

    Note: despite the name, this model carries no password hash; code that
    needs the hash reads it from the ORM ``User`` row.
    """

    id: int
    username: str
    email: EmailStr
    full_name: Union[str, None] = None
    is_active: bool = True
    is_superuser: bool = False
    scopes: List[str] = []

    class Config:
        # Allow building the model from attribute-bearing objects (e.g. ORM rows).
        from_attributes = True
|
|
| |
class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Each key may have at most ``rate_limit`` accepted requests within any
    trailing ``time_window``-second window. State lives only in this object,
    so every process enforces its own limits independently.
    """

    def __init__(
        self,
        rate_limit: int = 100,
        time_window: int = 60,
        clock: Callable[[], float] = time.monotonic,
    ):
        """
        Initialize rate limiter.

        Args:
            rate_limit: Maximum number of requests per time window
            time_window: Time window in seconds
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so wall-clock jumps
                cannot shrink or stretch the window; a fake clock can be
                injected for testing.
        """
        self.rate_limit = rate_limit
        self.time_window = time_window
        self._clock = clock
        # key -> timestamps (from ``clock``) of accepted requests, oldest first.
        self.requests: Dict[str, deque] = {}
        self._last_sweep = clock()

    def _evict_idle_keys(self, now: float) -> None:
        """Forget keys with no accepted request inside the window (bounds memory)."""
        cutoff = now - self.time_window
        idle = [
            key
            for key, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for key in idle:
            del self.requests[key]
        self._last_sweep = now

    def is_rate_limited(self, key: str) -> bool:
        """
        Check if a key is rate limited, recording the request when it is not.

        Rejected requests are not recorded, so they do not extend the window.

        Args:
            key: Identifier for the client (IP address, API key, etc.)

        Returns:
            True if rate limited, False otherwise
        """
        now = self._clock()

        # At most once per window, drop keys that have gone quiet. Without this
        # every distinct key ever seen (e.g. made-up header values) is kept
        # forever and the dict grows without bound.
        if now - self._last_sweep >= self.time_window:
            self._evict_idle_keys(now)

        stamps = self.requests.setdefault(key, deque())

        # Timestamps are appended in order, so expired ones are at the left.
        cutoff = now - self.time_window
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

        if len(stamps) >= self.rate_limit:
            return True

        stamps.append(now)
        return False
|
|
| |
# Process-wide limiter used by the rate_limit dependency: 100 requests per
# 60-second window per key (the RateLimiter defaults, spelled out).
rate_limiter = RateLimiter(rate_limit=100, time_window=60)
|
|
| |
| |
# Role name -> list of scope strings granted to that role.
ROLES = dict(
    admin=["read:all", "write:all", "delete:all"],
    analyst=["read:all", "write:threats", "write:indicators", "write:reports"],
    user=["read:threats", "read:reports", "read:dashboard"],
    api=["read:all", "write:threats", "write:indicators"],
)
|
|
| |
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plaintext password against a stored hash via ``pwd_context``.

    Returns:
        ``True`` if the password matches the hash, ``False`` otherwise.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
def get_password_hash(password: str) -> str:
    """
    Produce a storable hash of ``password`` using ``pwd_context``.

    Returns:
        The hash string to persist in place of the plaintext password.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
    """
    Look up a user by exact username and convert it to ``UserInDB``.

    Scopes are taken from ``ROLES``: superusers receive the ``"admin"`` scopes,
    every other user the ``"user"`` scopes.

    Args:
        db: Async SQLAlchemy session.
        username: Username to match.

    Returns:
        The converted user, or ``None`` when no row matches ``username``.
    """
    query = select(User).filter(User.username == username)
    user_db = (await db.execute(query)).scalars().first()
    if not user_db:
        return None

    # NOTE(review): only "admin" and "user" are derived here; the "analyst"
    # role in ROLES is never assigned by this function.
    role_name = "admin" if user_db.is_superuser else "user"

    return UserInDB(
        id=user_db.id,
        username=user_db.username,
        email=user_db.email,
        full_name=user_db.full_name,
        is_active=user_db.is_active,
        is_superuser=user_db.is_superuser,
        scopes=ROLES[role_name],
    )
|
|
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
    """
    Authenticate a user with username and password.

    Args:
        db: Async SQLAlchemy session.
        username: Username to authenticate.
        password: Plaintext password supplied by the client.

    Returns:
        The user (as built by ``get_user``) when the password matches;
        ``None`` for an unknown username and for a wrong password alike.
    """
    result = await db.execute(select(User).filter(User.username == username))
    user_db = result.scalars().first()

    if not user_db:
        # Spend roughly as long as a real hash check so the response time
        # does not reveal whether the username exists.
        pwd_context.dummy_verify()
        return None

    if not verify_password(password, user_db.hashed_password):
        return None

    # Build the public model (including role scopes) only after the password
    # checks out; get_user returns None if the row disappeared in between,
    # instead of crashing on a missing row.
    return await get_user(db, username)
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified; any ``"exp"`` entry in it is overwritten.
        expires_delta: Token lifetime. ``None`` means
            ``ACCESS_TOKEN_EXPIRE_MINUTES``; any explicit value, including
            ``timedelta(0)``, is honoured as given.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()

    # `is not None` instead of truthiness: timedelta(0) is falsy and was
    # silently replaced by the default lifetime.
    lifetime = (
        expires_delta
        if expires_delta is not None
        else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    # Aware UTC datetime instead of the deprecated datetime.utcnow();
    # python-jose converts datetime claims to UTC epoch seconds, so the
    # encoded "exp" value is unchanged.
    expire = datetime.now(timezone.utc) + lifetime

    to_encode["exp"] = expire
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
async def get_api_key_user(
    api_key: str,
    db: AsyncSession
) -> Optional[UserInDB]:
    """
    Resolve an API key to its user and grant that user the "api" role scopes.

    Args:
        api_key: Raw value of the API-key header.
        db: Async SQLAlchemy session.

    Returns:
        The mapped user with ``scopes`` set to the ``ROLES["api"]`` scopes,
        or ``None`` if the key is unknown or its user does not exist.
    """
    # NOTE(review): hard-coded key table. "test-api-key" is a credential
    # visible to anyone with source access; load keys from secure
    # configuration before production use. validate_api_key keeps its own
    # separate copy of the key list.
    API_KEYS = {
        "test-api-key": "api_user",
    }

    username = API_KEYS.get(api_key)
    if username is None:
        return None

    user = await get_user(db, username)
    if not user:
        return None

    # Copy the list: plain attribute assignment on a pydantic model stores the
    # object as-is (UserInDB does not enable validate_assignment), so assigning
    # ROLES["api"] directly would let any later mutation of user.scopes alter
    # the global role table for every future API-key user.
    user.scopes = list(ROLES["api"])

    return user
|
|
| |
async def rate_limit(request: Request):
    """
    Rate limiting dependency.

    Every request is counted against the caller's IP address. When an API-key
    header is present, the request is also counted against that key. The IP
    bucket is always checked because the header is not validated here: keying
    on it alone let a client get a fresh quota simply by sending a different
    made-up key with each request.

    Returns:
        ``True`` when the request is within limits.

    Raises:
        HTTPException: 429 when either the IP bucket or the key bucket is full.
    """
    # request.client is None when the server does not report a peer address;
    # such requests share a single "unknown" bucket.
    client_host = request.client.host if request.client else "unknown"

    # Prefixes keep an IP bucket and an API-key bucket from ever colliding.
    buckets = [f"ip:{client_host}"]
    api_key = request.headers.get(API_KEY_NAME)
    if api_key:
        buckets.append(f"key:{api_key}")

    for bucket in buckets:
        if rate_limiter.is_rate_limited(bucket):
            # Log only the host: the API key is a credential and must not
            # be written to log files.
            logger.warning("Rate limit exceeded for %s", client_host)
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded. Please try again later."
            )

    return True
|
|
# Optional variant of ``oauth2_scheme``. With the default auto_error=True,
# OAuth2PasswordBearer rejects any request lacking a Bearer Authorization
# header with a 401 before get_current_user runs, which made the X-API-Key
# branch below unreachable for key-only clients.
# NOTE: keep tokenUrl in sync with ``oauth2_scheme``.
_optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


async def get_current_user(
    token: Optional[str] = Depends(_optional_oauth2_scheme),
    api_key: Optional[str] = Security(api_key_header),
    db: AsyncSession = Depends(get_db)
) -> UserInDB:
    """
    Get the current user from either an API key or a JWT bearer token.

    The API key header is tried first; if it is absent or does not resolve to
    a user, the bearer token is decoded instead. The returned user's scopes
    come from get_api_key_user / get_user, not from the token's ``scopes``
    claim.

    Raises:
        HTTPException: 401 if neither credential identifies an existing user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if api_key:
        user = await get_api_key_user(api_key, db)
        if user:
            return user

    # No usable API key and no bearer token: nothing left to validate.
    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    token_data = TokenData(
        username=username,
        scopes=payload.get("scopes", [])
    )

    user = await get_user(db, username=token_data.username)
    if user is None:
        raise credentials_exception

    return user
|
|
async def get_current_active_user(
    current_user: UserInDB = Depends(get_current_user)
) -> UserInDB:
    """
    Dependency that yields the authenticated user only if the account is active.

    Raises:
        HTTPException: 400 when the user's ``is_active`` flag is false.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
|
|
def has_scope(required_scopes: List[str]):
    """
    Build a dependency that requires the active user to hold every listed scope.

    Scopes are matched as literal strings, so a user holding only
    ``"read:all"`` does not satisfy a requirement of ``"read:threats"``.
    NOTE(review): confirm that is intended, given the ``*:all`` scopes in ROLES.

    Args:
        required_scopes: Scopes the user must all have.

    Returns:
        An async dependency returning the user, or raising 403 that names the
        first required scope (in list order) the user lacks.
    """
    async def _require_scopes(
        current_user: UserInDB = Depends(get_current_active_user)
    ) -> UserInDB:
        missing = [scope for scope in required_scopes if scope not in current_user.scopes]
        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Permission denied. Required scope: {missing[0]}"
            )
        return current_user

    return _require_scopes
|
|
def admin_only(
    current_user: UserInDB = Depends(get_current_active_user)
) -> UserInDB:
    """
    Dependency that lets only superusers through.

    Raises:
        HTTPException: 403 when the active user is not a superuser.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Permission denied. Admin access required.",
    )
|
|
| |
async def audit_log_middleware(request: Request, call_next):
    """
    Middleware for audit logging.

    Logs one line when a request arrives and one when its response is ready
    (status code and elapsed time). If the downstream handler raises, the
    failure is logged with its traceback and the exception is re-raised
    unchanged.

    Args:
        request: Incoming request.
        call_next: Callable that passes the request down the stack and
            returns the response.

    Returns:
        The downstream response, unmodified.
    """
    method = request.method
    path = request.url.path
    # request.client is None when the server does not report a peer address.
    client_host = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("User-Agent", "Unknown")

    # NOTE(review): nothing in this module sets request.state.user, and it is
    # read before the handler runs — confirm something upstream populates it,
    # otherwise every entry is logged as "Anonymous".
    user = getattr(request.state, "user", None)
    username = user.username if user else "Anonymous"

    logger.info(
        "API Request: %s %s | User: %s | Client: %s | User-Agent: %s",
        method, path, username, client_host, user_agent,
    )

    # perf_counter is monotonic, so durations are unaffected by clock changes.
    start_time = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        logger.exception(
            "API Error: %s %s | Time: %.4fs | User: %s",
            method, path, time.perf_counter() - start_time, username,
        )
        raise
    process_time = time.perf_counter() - start_time

    logger.info(
        "API Response: %s %s | Status: %s | Time: %.4fs | User: %s",
        method, path, response.status_code, process_time, username,
    )

    return response
|
|
| |
def validate_api_key(request: Request):
    """
    Dependency that requires a valid API key in the ``API_KEY_NAME`` header.

    Returns:
        ``True`` when the key is accepted.

    Raises:
        HTTPException: 401 if the header is missing or the key is not valid.
    """
    api_key = request.headers.get(API_KEY_NAME)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
        )

    # NOTE(review): hard-coded development key, duplicated separately in
    # get_api_key_user; move to secure configuration before production use.
    valid_keys = ["test-api-key"]

    # Constant-time comparison against every key so response timing does not
    # reveal how much of a guessed key matched. The list (not a generator) is
    # deliberate: it prevents any() from stopping at the first match. Bytes
    # are compared because compare_digest rejects non-ASCII str arguments.
    supplied = api_key.encode("utf-8")
    is_valid = any([secrets.compare_digest(supplied, key.encode("utf-8")) for key in valid_keys])
    if not is_valid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
        )

    return True