|
|
""" |
|
|
Enterprise API Key Authentication System for MCP Servers |
|
|
|
|
|
Features: |
|
|
- API key generation and validation |
|
|
- Key rotation support |
|
|
- Per-key expiry and rate-limit configuration
|
|
- Audit logging of authentication attempts |
|
|
- Multiple authentication methods (X-API-Key header, Authorization: Bearer header, api_key query parameter) plus optional HMAC request signing
|
|
""" |
|
|
import os |
|
|
import secrets |
|
|
import hashlib |
|
|
import hmac |
|
|
import logging |
|
|
from typing import Optional, Dict, Set, Tuple |
|
|
from datetime import datetime, timedelta |
|
|
from dataclasses import dataclass |
|
|
from aiohttp import web |
|
|
|
|
|
# Module-level logger, named after this module, shared by every class below.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
@dataclass
class APIKey:
    """API key record with metadata.

    Only a hash of the secret is kept (``key_hash``); the plain key is not
    stored on this object.

    Attributes:
        key_id: Short identifier, e.g. ``"primary"`` or ``"key_3"``.
        key_hash: Hash of the plain key, used as the lookup key by the manager.
        name: Human-readable label.
        tenant_id: Optional tenant the key belongs to.
        created_at: Naive UTC creation time; filled with ``utcnow()`` when omitted.
        expires_at: Naive UTC expiry time, or ``None`` for a key that never expires.
        is_active: ``False`` once the key has been revoked.
        permissions: Permission strings; a fresh empty set when omitted.
        rate_limit: Configured per-key limit (stored only; not enforced here).
        metadata: Free-form extra data; a fresh empty dict when omitted.
    """

    key_id: str
    key_hash: str
    name: str
    tenant_id: Optional[str] = None
    # The mutable/"now" defaults are materialised in __post_init__ so every
    # instance gets its own objects; the annotations say Optional because
    # None is the sentinel meaning "use the default".
    created_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None
    is_active: bool = True
    permissions: Optional[Set[str]] = None
    rate_limit: int = 100
    metadata: Optional[Dict] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.utcnow()
        if self.permissions is None:
            self.permissions = set()
        if self.metadata is None:
            self.metadata = {}

    def is_expired(self, now: Optional[datetime] = None) -> bool:
        """Return True if the key's expiry time has passed.

        Args:
            now: Naive UTC reference time; defaults to ``datetime.utcnow()``.
                A key is expired strictly after ``expires_at``.
        """
        if self.expires_at is None:
            return False
        reference = datetime.utcnow() if now is None else now
        return reference > self.expires_at

    def is_valid(self, now: Optional[datetime] = None) -> bool:
        """Return True if the key is active and not expired at *now*."""
        return self.is_active and not self.is_expired(now)
|
|
|
|
|
|
|
|
class APIKeyManager:
    """
    API Key Manager with secure key storage and validation.

    Keys are held in memory in ``self.keys``, indexed by the SHA-256 hex
    digest of the plain key; the plain key itself is never stored. Keys are
    loaded from the environment on construction, added with ``create_key``
    and deactivated with ``revoke_key``.
    """

    def __init__(self):
        # key hash -> APIKey record
        self.keys: Dict[str, APIKey] = {}
        self._load_keys_from_env()
        logger.info(f"API Key Manager initialized with {len(self.keys)} keys")

    def _load_keys_from_env(self):
        """Load API keys from environment variables.

        * ``MCP_API_KEY``: registered as ``"primary"`` with rate limit 1000.
        * ``MCP_API_KEYS``: comma-separated list; each non-blank, stripped
          entry is registered as ``key_<n>`` where ``n`` is its 1-based
          position in the list, with rate limit 100.

        Every environment key receives the wildcard permission ``"*"``.
        """
        primary_key = os.getenv("MCP_API_KEY")
        if primary_key:
            self._register_env_key(primary_key, "primary", "Primary API Key", 1000)
            logger.info("Loaded primary API key from environment")

        additional_keys = os.getenv("MCP_API_KEYS", "")
        if additional_keys:
            loaded = 0
            for idx, key in enumerate(additional_keys.split(",")):
                key = key.strip()
                if not key:
                    # Blank entries (e.g. a trailing comma) are skipped but still
                    # advance idx, so numbering follows list position.
                    continue
                self._register_env_key(key, f"key_{idx + 1}", f"API Key {idx + 1}", 100)
                loaded += 1
            # Report the keys actually registered, not the raw comma-separated fields.
            logger.info(f"Loaded {loaded} additional API keys")

    def _register_env_key(self, plain_key: str, key_id: str, name: str, rate_limit: int) -> APIKey:
        """Hash *plain_key* and store it as an active wildcard-permission key."""
        key_hash = self._hash_key(plain_key)
        api_key = APIKey(
            key_id=key_id,
            key_hash=key_hash,
            name=name,
            is_active=True,
            permissions={"*"},
            rate_limit=rate_limit,
        )
        self.keys[key_hash] = api_key
        return api_key

    @staticmethod
    def generate_api_key() -> str:
        """
        Generate a secure API key.

        Format: ``mcp_<64 hex chars>`` (32 random bytes from ``secrets``).
        """
        random_bytes = secrets.token_bytes(32)
        return f"mcp_{random_bytes.hex()}"

    @staticmethod
    def _hash_key(key: str) -> str:
        """Hash an API key using SHA-256 and return the hex digest."""
        return hashlib.sha256(key.encode()).hexdigest()

    def _next_key_id(self) -> str:
        """Return a ``key_<n>`` id not used by any registered key.

        Starts at ``len(self.keys) + 1`` (the original numbering scheme) and
        skips ahead past ids already taken, e.g. by environment keys whose
        number follows their position in ``MCP_API_KEYS``.
        """
        used = {api_key.key_id for api_key in self.keys.values()}
        n = len(self.keys) + 1
        while f"key_{n}" in used:
            n += 1
        return f"key_{n}"

    def create_key(
        self,
        name: str,
        tenant_id: Optional[str] = None,
        expires_in_days: Optional[int] = None,
        permissions: Optional[Set[str]] = None,
        rate_limit: int = 100
    ) -> Tuple[str, APIKey]:
        """
        Create and register a new API key.

        Args:
            name: Human-readable label for the key.
            tenant_id: Optional tenant the key belongs to.
            expires_in_days: Lifetime in days; ``None`` means no expiry.
                ``0`` (or a negative value) yields an already-expiring key
                rather than an immortal one.
            permissions: Permission set; ``None`` grants the wildcard ``"*"``.
                An explicitly empty set is kept empty.
            rate_limit: Configured per-key limit stored on the record.

        Returns:
            Tuple of (plain_key, api_key_object). The plain key is not stored
            and cannot be recovered later.
        """
        plain_key = self.generate_api_key()
        key_hash = self._hash_key(plain_key)

        expires_at = None
        # `is not None`: a truthiness test would turn 0 days into "never expires".
        if expires_in_days is not None:
            expires_at = datetime.utcnow() + timedelta(days=expires_in_days)

        # `permissions or {"*"}` would silently escalate an explicit empty set to
        # full access. Copying also detaches the key from the caller's set.
        key_permissions = {"*"} if permissions is None else set(permissions)

        api_key = APIKey(
            key_id=self._next_key_id(),
            key_hash=key_hash,
            name=name,
            tenant_id=tenant_id,
            expires_at=expires_at,
            permissions=key_permissions,
            rate_limit=rate_limit
        )

        self.keys[key_hash] = api_key
        logger.info(f"Created new API key: {api_key.key_id} for {name}")

        return plain_key, api_key

    def validate_key(self, plain_key: str) -> Optional[APIKey]:
        """
        Validate an API key.

        Returns:
            The APIKey record if the key is known, active and unexpired;
            None otherwise (including for an empty key).
        """
        if not plain_key:
            return None

        api_key = self.keys.get(self._hash_key(plain_key))

        if not api_key:
            logger.warning("Invalid API key provided")
            return None

        if not api_key.is_valid():
            logger.warning(f"Expired or inactive API key: {api_key.key_id}")
            return None

        return api_key

    def revoke_key(self, key_hash: str) -> bool:
        """Deactivate the key stored under *key_hash*.

        The record is kept (it still appears in ``list_keys``) but it no
        longer validates.

        Returns:
            True if a key with that hash existed, False otherwise.
        """
        api_key = self.keys.get(key_hash)
        if api_key is None:
            return False
        api_key.is_active = False
        logger.info(f"Revoked API key: {api_key.key_id}")
        return True

    def list_keys(self) -> list[APIKey]:
        """Return all registered API key records, including revoked ones."""
        return list(self.keys.values())
|
|
|
|
|
|
|
|
class APIKeyAuthMiddleware:
    """
    aiohttp middleware for API key authentication.

    Requests whose path is in ``exempt_paths`` bypass authentication. On
    success the matched APIKey and its tenant are stored on the request as
    ``request["api_key"]`` and ``request["tenant_id"]``.
    """

    def __init__(self, key_manager: APIKeyManager, exempt_paths: Optional[Set[str]] = None):
        self.key_manager = key_manager
        # `is None` rather than `or`: an explicitly empty collection means
        # "exempt nothing" and must not fall back to the defaults.
        self.exempt_paths = {"/health", "/metrics"} if exempt_paths is None else exempt_paths
        logger.info("API Key Auth Middleware initialized")

    @web.middleware
    async def middleware(self, request: web.Request, handler):
        """Authenticate *request* by API key, then delegate to *handler*.

        Returns a 401 JSON response when the key is missing, unknown,
        revoked or expired.
        """
        if request.path in self.exempt_paths:
            return await handler(request)

        api_key = self._extract_api_key(request)
        if not api_key:
            logger.warning(f"No API key provided for {request.path}")
            return self._unauthorized("Authentication required", "API key missing")

        key_obj = self.key_manager.validate_key(api_key)
        if not key_obj:
            logger.warning(f"Invalid API key for {request.path}")
            return self._unauthorized("Authentication failed", "Invalid or expired API key")

        # Expose the authenticated identity to downstream handlers.
        request["api_key"] = key_obj
        request["tenant_id"] = key_obj.tenant_id

        logger.debug(f"Authenticated request: {request.path} with key {key_obj.key_id}")

        return await handler(request)

    @staticmethod
    def _unauthorized(error: str, message: str) -> web.Response:
        """Build a 401 JSON response with the challenge header RFC 7235 requires."""
        return web.json_response(
            {"error": error, "message": message},
            status=401,
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _extract_api_key(self, request: web.Request) -> Optional[str]:
        """
        Extract the API key from a request.

        Checked in order:
        - X-API-Key header
        - Authorization: Bearer <key> header (scheme matched case-insensitively)
        - api_key query parameter (logged as insecure)

        Returns:
            The key string, or None if no source supplied a non-empty key.
        """
        api_key = request.headers.get("X-API-Key")
        if api_key:
            return api_key

        # RFC 7235: the auth-scheme is case-insensitive, so "bearer" must match too.
        # split(None, 1) tolerates any whitespace between the scheme and the token.
        parts = request.headers.get("Authorization", "").split(None, 1)
        if len(parts) == 2 and parts[0].lower() == "bearer":
            token = parts[1].strip()
            if token:
                return token

        api_key = request.query.get("api_key")
        if api_key:
            logger.warning("API key provided via query parameter (insecure)")
            return api_key

        return None
|
|
|
|
|
|
|
|
class RequestSigningAuth:
    """
    Request signing authentication using HMAC-SHA256.

    Clients send an ``X-Signature`` header holding the hex HMAC of
    ``"{method}|{path}|{body}|{timestamp}"`` and an ``X-Timestamp`` header
    holding an ISO 8601 timestamp. Requests whose timestamp is more than
    ``max_skew_seconds`` away from server time are rejected.

    NOTE(review): the query string is not part of the signed message, ``|``
    inside path/body is not escaped (different path/body splits can sign to
    the same message), and a signature can be replayed inside the skew
    window. Changing the message format would break existing clients, so it
    is left unchanged here.
    """

    def __init__(
        self,
        secret_key: Optional[str] = None,
        exempt_paths: Optional[Set[str]] = None,
        max_skew_seconds: float = 300,
    ):
        """
        Args:
            secret_key: HMAC secret; falls back to the MCP_SECRET_KEY env var.
                If neither is set, every signature is rejected.
            exempt_paths: Paths that skip verification; defaults to
                ``{"/health", "/metrics"}``.
            max_skew_seconds: Maximum allowed |server time - X-Timestamp|.
        """
        self.secret_key = secret_key or os.getenv("MCP_SECRET_KEY", "")
        if not self.secret_key:
            logger.warning("No secret key provided for request signing")
        self.exempt_paths = {"/health", "/metrics"} if exempt_paths is None else exempt_paths
        self.max_skew_seconds = max_skew_seconds

    def sign_request(self, method: str, path: str, body: str, timestamp: str) -> str:
        """
        Sign a request using HMAC-SHA256.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: Request path
            body: Request body as text (empty string when there is none)
            timestamp: ISO timestamp, exactly as sent by the client

        Returns:
            HMAC signature (lowercase hex string)
        """
        message = f"{method}|{path}|{body}|{timestamp}"
        return hmac.new(
            self.secret_key.encode(),
            message.encode(),
            hashlib.sha256
        ).hexdigest()

    def verify_signature(
        self,
        method: str,
        path: str,
        body: str,
        timestamp: str,
        signature: str
    ) -> bool:
        """
        Verify a request signature.

        Accepts naive timestamps (treated as UTC) and timestamps with a "Z"
        suffix or a numeric UTC offset.

        Returns:
            True if the secret is configured, the timestamp is within the
            allowed skew and the signature matches; False otherwise.
        """
        if not self.secret_key:
            # Fail closed: with an empty secret anyone can compute a valid HMAC.
            logger.error("Request signing secret is not configured; rejecting signature")
            return False

        try:
            request_time = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            offset = request_time.utcoffset()
            if offset is not None:
                # Normalise aware timestamps to naive UTC. Subtracting an aware
                # datetime from naive utcnow() raises TypeError, which used to
                # reject every "Z"/offset timestamp.
                request_time = request_time.replace(tzinfo=None) - offset
            time_diff = abs((datetime.utcnow() - request_time).total_seconds())
        except (ValueError, TypeError, AttributeError, OverflowError) as e:
            logger.error(f"Invalid timestamp format: {e}")
            return False

        if time_diff > self.max_skew_seconds:
            logger.warning(f"Request timestamp too old: {time_diff}s")
            return False

        # hmac.compare_digest raises TypeError for non-ASCII str arguments;
        # such a header can never equal a hex digest anyway.
        if not isinstance(signature, str) or not signature.isascii():
            return False

        expected_signature = self.sign_request(method, path, body, timestamp)
        return hmac.compare_digest(expected_signature, signature)

    @web.middleware
    async def middleware(self, request: web.Request, handler):
        """Verify the request signature, then delegate to *handler*.

        Returns 401 for a missing or invalid signature and 400 for a body
        that is not valid UTF-8 (the signed message is text).
        """
        if request.path in self.exempt_paths:
            return await handler(request)

        signature = request.headers.get("X-Signature")
        timestamp = request.headers.get("X-Timestamp")

        if not signature or not timestamp:
            return web.json_response(
                {"error": "Missing signature or timestamp"},
                status=401
            )

        body = ""
        if request.can_read_body:
            body_bytes = await request.read()
            try:
                body = body_bytes.decode()
            except UnicodeDecodeError:
                # Previously an unhandled exception (HTTP 500).
                logger.warning(f"Non-UTF-8 request body for {request.path}")
                return web.json_response(
                    {"error": "Invalid request body encoding"},
                    status=400
                )

        if not self.verify_signature(
            request.method,
            request.path,
            body,
            timestamp,
            signature
        ):
            logger.warning(f"Invalid signature for {request.path}")
            return web.json_response(
                {"error": "Invalid signature"},
                status=401
            )

        return await handler(request)
|
|
|
|
|
|
|
|
|
|
|
# Process-wide APIKeyManager, created lazily on first access.
_key_manager: Optional[APIKeyManager] = None


def get_key_manager() -> APIKeyManager:
    """
    Return the shared APIKeyManager, constructing it on the first call.

    Construction loads keys from the environment, so environment changes
    made after the first call are not picked up.
    """
    global _key_manager
    manager = _key_manager
    if manager is None:
        manager = APIKeyManager()
        _key_manager = manager
    return manager
|
|
|