Spaces:
Paused
Paused
"""
API Keys Service

This module provides functionality for managing API keys, including:
- Creating new API key pairs (a random ``pk_`` public key and ``sk_`` secret key)
- Validating API keys for authentication
- Managing expiration and revocation
- CRUD operations for user API keys
"""
import asyncio
import hashlib
import hmac
import json
import secrets
import string
import time
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict
from uuid import UUID, uuid4

from fastapi import HTTPException
from pydantic import BaseModel, Field, field_validator

from services import redis
from services.supabase import DBConnection
from utils.config import config
from utils.logger import logger
class APIKeyStatus:
    """Plain string constants for the lifecycle value stored in ``api_keys.status``."""

    ACTIVE, REVOKED, EXPIRED = "active", "revoked", "expired"
class APIKeyCreateRequest(BaseModel):
    """Request model for creating a new API key"""

    title: str = Field(
        ...,
        min_length=1,
        max_length=255,
        description="Human-readable title for the API key",
    )
    description: Optional[str] = Field(
        None, description="Optional description for the API key"
    )
    # None means the key never expires.
    expires_in_days: Optional[int] = Field(
        None, gt=0, le=365, description="Number of days until expiration (max 365)"
    )

    @field_validator("title")
    @classmethod
    def validate_title(cls, v: str) -> str:
        """
        Reject whitespace-only titles and strip surrounding whitespace.

        Without the ``field_validator`` decorator this method was never
        registered with pydantic. A title such as ``"   "`` then satisfied
        ``min_length=1`` and was stored unstripped.

        Raises:
            ValueError: if the title is empty or only whitespace
                (surfaced by pydantic as a validation error)
        """
        if not v or not v.strip():
            raise ValueError("Title cannot be empty")
        return v.strip()
class APIKeyResponse(BaseModel):
    """
    Response model for API key information (without the secret key).

    The Optional fields have no defaults, so they must be supplied explicitly,
    possibly as None.
    """
    key_id: UUID
    public_key: str
    title: str
    description: Optional[str]
    # Raw api_keys.status value; this module writes APIKeyStatus constants.
    status: str
    # None when the key was created without an expiry.
    expires_at: Optional[datetime]
    # None until the key is first used; updates are throttled, so it may lag.
    last_used_at: Optional[datetime]
    created_at: datetime
class APIKeyCreateResponse(BaseModel):
    """
    Response model for a newly created API key (includes both keys).

    This is the only response that carries the plaintext secret key; the
    database stores only its HMAC, so it cannot be shown again later.
    """
    key_id: UUID
    public_key: str
    secret_key: str  # Only returned on creation
    title: str
    description: Optional[str]
    # Raw api_keys.status value; this module writes APIKeyStatus constants.
    status: str
    # None when the key was created without an expiry.
    expires_at: Optional[datetime]
    created_at: datetime
class APIKeyValidationResult(BaseModel):
    """
    Result of API key validation.

    On success ``is_valid`` is True and ``account_id``/``key_id`` identify the
    key's owner. On failure ``is_valid`` is False and ``error_message``
    explains why.
    """
    is_valid: bool
    # Set only when is_valid is True.
    account_id: Optional[UUID] = None
    # Set only when is_valid is True.
    key_id: Optional[UUID] = None
    # Set only when is_valid is False.
    error_message: Optional[str] = None
class APIKeyService:
    """
    Service for managing API keys with performance optimizations.

    Performance Features:
    - HMAC-SHA256 hashing of secret keys (far cheaper than bcrypt)
    - Redis caching of validation results: successes for 2 min, unknown keys
      and wrong secrets for 5 min, expired or inactive keys for 1 hour
    - Throttled last_used_at updates (at most once per
      config.API_KEY_LAST_USED_THROTTLE_SECONDS per key)
    - last_used_at writes run as background tasks so validation does not
      wait on them
    - In-memory fallback throttling when Redis is unavailable
    - Revoking or deleting a key evicts its cached successful validation
      (best effort)
    """

    # Class-level in-memory throttle cache (fallback when Redis unavailable).
    # Shared by every instance, so it is only ever mutated in place. Rebinding
    # it on ``self`` would silently give one instance a private copy.
    _throttle_cache: Dict[str, float] = {}

    # Prune _throttle_cache once it holds more than this many entries.
    _THROTTLE_CACHE_MAX_ENTRIES = 1000

    # Strong references to fire-and-forget tasks. The event loop keeps only
    # weak references to tasks, so an unreferenced task may be
    # garbage-collected before it finishes.
    _background_tasks: set[asyncio.Task] = set()

    def __init__(self, db: DBConnection):
        """
        Args:
            db: Database connection whose ``client`` is awaited for every query
        """
        self.db = db

    def _generate_key_pair(self) -> tuple[str, str]:
        """
        Generate a public key and secret key pair

        Returns:
            tuple: (public_key, secret_key), where public_key is 'pk_' and
            secret_key is 'sk_' followed by 32 random alphanumeric characters
            drawn with the ``secrets`` CSPRNG
        """
        alphabet = string.ascii_letters + string.digits
        pk_suffix = "".join(secrets.choice(alphabet) for _ in range(32))
        sk_suffix = "".join(secrets.choice(alphabet) for _ in range(32))
        return f"pk_{pk_suffix}", f"sk_{sk_suffix}"

    def _get_secret_key(self) -> str:
        """Get the server-side secret used as the HMAC key"""
        return config.API_KEY_SECRET

    def _hash_secret_key(self, secret_key: str) -> str:
        """
        Hash a secret key using HMAC-SHA256 (much faster than bcrypt)

        Args:
            secret_key: The secret key to hash

        Returns:
            str: The hex-encoded HMAC-SHA256 of the secret key
        """
        secret = self._get_secret_key().encode("utf-8")
        return hmac.new(secret, secret_key.encode("utf-8"), hashlib.sha256).hexdigest()

    def _verify_secret_key(self, secret_key: str, hashed_key: str) -> bool:
        """
        Verify a secret key against its hash using constant-time comparison

        Args:
            secret_key: The secret key to verify
            hashed_key: The stored hash

        Returns:
            bool: True if the secret key matches the hash; False on mismatch
            or on any error (e.g. a missing stored hash)
        """
        try:
            expected_hash = self._hash_secret_key(secret_key)
            return hmac.compare_digest(expected_hash, hashed_key)
        except Exception:
            return False

    @staticmethod
    def _validation_cache_key(public_key: str, secret_hash: str) -> str:
        """
        Redis key under which a validation result for a key pair is cached.

        The full HMAC digest is used, not a prefix, for two reasons. First, a
        different secret can never hit another secret's cached success.
        Second, the correct secret's entry can be rebuilt from the stored
        ``secret_key_hash`` when the key is revoked or deleted.
        """
        return f"api_key:{public_key}:{secret_hash}"

    def _spawn_background(self, coro) -> None:
        """Run ``coro`` as a fire-and-forget task, holding a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def create_api_key(
        self, account_id: UUID, request: APIKeyCreateRequest
    ) -> APIKeyCreateResponse:
        """
        Create a new API key for the specified account

        Args:
            account_id: The account ID to create the key for
            request: The API key creation request

        Returns:
            APIKeyCreateResponse containing the new API key details. It
            includes the plaintext secret key, which is returned only here;
            the database stores just its HMAC.

        Raises:
            HTTPException: 500 if the insert returns no row or anything fails
        """
        try:
            # Calculate expiration date if specified
            expires_at = None
            if request.expires_in_days:
                expires_at = datetime.now(timezone.utc) + timedelta(
                    days=request.expires_in_days
                )

            public_key, secret_key = self._generate_key_pair()
            # Only the HMAC of the secret is persisted
            secret_key_hash = self._hash_secret_key(secret_key)

            client = await self.db.client
            result = (
                await client.table("api_keys")
                .insert(
                    {
                        "public_key": public_key,
                        "secret_key_hash": secret_key_hash,
                        "account_id": str(account_id),
                        "title": request.title,
                        "description": request.description,
                        "expires_at": expires_at.isoformat() if expires_at else None,
                        "status": APIKeyStatus.ACTIVE,
                    }
                )
                .execute()
            )
            if not result.data:
                raise HTTPException(status_code=500, detail="Failed to create API key")

            key_data = result.data[0]
            logger.debug(
                "API key created successfully",
                account_id=str(account_id),
                key_id=key_data["key_id"],
                public_key=public_key,
                title=request.title,
            )
            return APIKeyCreateResponse(
                key_id=UUID(key_data["key_id"]),
                public_key=public_key,
                secret_key=secret_key,  # Only returned on creation
                title=key_data["title"],
                description=key_data["description"],
                status=key_data["status"],
                expires_at=(
                    datetime.fromisoformat(key_data["expires_at"])
                    if key_data["expires_at"]
                    else None
                ),
                created_at=datetime.fromisoformat(key_data["created_at"]),
            )
        except HTTPException:
            # Already an HTTP error; re-raise as-is (same as revoke/delete)
            raise
        except Exception as e:
            logger.error(f"Error creating API key: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Failed to create API key") from e

    async def list_api_keys(self, account_id: UUID) -> List[APIKeyResponse]:
        """
        List all API keys for the specified account, newest first

        Args:
            account_id: The account ID to list keys for

        Returns:
            List of APIKeyResponse objects (never includes secrets or hashes)

        Raises:
            HTTPException: 500 if the query or parsing fails
        """
        try:
            client = await self.db.client
            result = (
                await client.table("api_keys")
                .select(
                    "key_id, public_key, title, description, status, expires_at, last_used_at, created_at"
                )
                .eq("account_id", str(account_id))
                .order("created_at", desc=True)
                .execute()
            )
            return [
                APIKeyResponse(
                    key_id=UUID(key_data["key_id"]),
                    public_key=key_data["public_key"],
                    title=key_data["title"],
                    description=key_data["description"],
                    status=key_data["status"],
                    expires_at=(
                        datetime.fromisoformat(key_data["expires_at"])
                        if key_data["expires_at"]
                        else None
                    ),
                    last_used_at=(
                        datetime.fromisoformat(key_data["last_used_at"])
                        if key_data["last_used_at"]
                        else None
                    ),
                    created_at=datetime.fromisoformat(key_data["created_at"]),
                )
                for key_data in result.data
            ]
        except Exception as e:
            logger.error(f"Error listing API keys: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Failed to list API keys") from e

    async def revoke_api_key(self, account_id: UUID, key_id: UUID) -> bool:
        """
        Revoke an API key and evict its cached successful validation

        Args:
            account_id: The account ID that owns the key
            key_id: The ID of the key to revoke

        Returns:
            True if successful

        Raises:
            HTTPException: 404 if no key matches (key_id, account_id);
                500 on any other failure
        """
        try:
            client = await self.db.client
            result = (
                await client.table("api_keys")
                .update({"status": APIKeyStatus.REVOKED})
                .eq("key_id", str(key_id))
                .eq("account_id", str(account_id))
                .execute()
            )
            if not result.data:
                raise HTTPException(status_code=404, detail="API key not found")

            # Otherwise a cached success would keep the key usable for its TTL
            await self._invalidate_validation_cache(result.data)

            logger.debug(
                "API key revoked successfully",
                account_id=str(account_id),
                key_id=str(key_id),
            )
            return True
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error revoking API key: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Failed to revoke API key") from e

    async def validate_api_key(
        self, public_key: str, secret_key: str
    ) -> APIKeyValidationResult:
        """
        Validate an API key pair with Redis caching for performance

        Args:
            public_key: The public key (starts with 'pk_')
            secret_key: The secret key (starts with 'sk_')

        Returns:
            APIKeyValidationResult with validation status and account info.
            Never raises: unexpected errors are logged and returned as an
            invalid result with error_message "Internal server error".
        """
        try:
            # Validate key format
            if not public_key.startswith("pk_") or not secret_key.startswith("sk_"):
                return APIKeyValidationResult(
                    is_valid=False, error_message="Invalid API key format"
                )

            cache_key = self._validation_cache_key(
                public_key, self._hash_secret_key(secret_key)
            )

            # Check Redis cache first; any Redis failure falls through to the DB
            try:
                redis_client = await redis.get_client()
                cached_result = await redis_client.get(cache_key)
                if cached_result:
                    cached_data = json.loads(cached_result)
                    logger.debug(f"API key validation cache hit for {public_key}")
                    return APIKeyValidationResult(
                        is_valid=cached_data["is_valid"],
                        account_id=(
                            UUID(cached_data["account_id"])
                            if cached_data["account_id"]
                            else None
                        ),
                        key_id=(
                            UUID(cached_data["key_id"])
                            if cached_data["key_id"]
                            else None
                        ),
                        error_message=cached_data.get("error_message"),
                    )
            except Exception as e:
                logger.warning(f"Redis cache lookup failed: {e}")

            # Single lookup by public key; the secret is checked locally against its hash
            client = await self.db.client
            result = (
                await client.table("api_keys")
                .select("key_id, account_id, status, expires_at, secret_key_hash")
                .eq("public_key", public_key)
                .execute()
            )

            if not result.data:
                validation_result = APIKeyValidationResult(
                    is_valid=False, error_message="API key not found"
                )
                await self._cache_validation_result(
                    cache_key, validation_result, ttl=300
                )  # Cache negative results for 5 min
                return validation_result

            key_data = result.data[0]

            # Verify the secret before revealing anything about the key's
            # state, so a caller holding only the public key cannot learn
            # whether it is expired or revoked.
            if not self._verify_secret_key(secret_key, key_data["secret_key_hash"]):
                validation_result = APIKeyValidationResult(
                    is_valid=False, error_message="Invalid secret key"
                )
                await self._cache_validation_result(
                    cache_key, validation_result, ttl=300
                )  # Cache invalid for 5 min
                return validation_result

            if key_data["expires_at"]:
                expires_at = datetime.fromisoformat(key_data["expires_at"])
                if expires_at < datetime.now(timezone.utc):
                    validation_result = APIKeyValidationResult(
                        is_valid=False, error_message="API key expired"
                    )
                    await self._cache_validation_result(
                        cache_key, validation_result, ttl=3600
                    )  # Cache expired for 1 hour
                    return validation_result

            if key_data["status"] != APIKeyStatus.ACTIVE:
                validation_result = APIKeyValidationResult(
                    is_valid=False, error_message=f"API key is {key_data['status']}"
                )
                await self._cache_validation_result(
                    cache_key, validation_result, ttl=3600
                )  # Cache inactive for 1 hour
                return validation_result

            validation_result = APIKeyValidationResult(
                is_valid=True,
                account_id=UUID(key_data["account_id"]),
                key_id=UUID(key_data["key_id"]),
            )
            # Cache successful validation for 2 minutes
            await self._cache_validation_result(cache_key, validation_result, ttl=120)

            # Update last used timestamp in the background, throttled to at most
            # once per config.API_KEY_LAST_USED_THROTTLE_SECONDS per key
            self._spawn_background(self._update_last_used_throttled(key_data["key_id"]))

            return validation_result

        except Exception as e:
            logger.error(f"Error validating API key: {e}", exc_info=True)
            return APIKeyValidationResult(
                is_valid=False, error_message="Internal server error"
            )

    async def _cache_validation_result(
        self, cache_key: str, result: APIKeyValidationResult, ttl: int = 120
    ):
        """Best-effort: cache a validation result in Redis for ``ttl`` seconds"""
        try:
            redis_client = await redis.get_client()
            cache_data = {
                "is_valid": result.is_valid,
                "account_id": str(result.account_id) if result.account_id else None,
                "key_id": str(result.key_id) if result.key_id else None,
                "error_message": result.error_message,
            }
            await redis_client.setex(cache_key, ttl, json.dumps(cache_data))
        except Exception as e:
            logger.warning(f"Failed to cache validation result: {e}")

    async def _invalidate_validation_cache(self, rows: List[dict]) -> None:
        """
        Best-effort eviction of the cached successful validation for each row.

        A success is cached under the public key plus the HMAC of the correct
        secret, and that HMAC is exactly the stored ``secret_key_hash``. The
        entry can therefore be deleted without knowing the secret. Rows missing
        either column (e.g. if the write did not return them) are skipped. A
        validation already in flight may still re-cache a success briefly.
        """
        cache_keys = [
            self._validation_cache_key(row["public_key"], row["secret_key_hash"])
            for row in rows or []
            if row.get("public_key") and row.get("secret_key_hash")
        ]
        if not cache_keys:
            return
        try:
            redis_client = await redis.get_client()
            await redis_client.delete(*cache_keys)
        except Exception as e:
            logger.warning(f"Failed to invalidate cached API key validation: {e}")

    async def _update_last_used_throttled(self, key_id: str):
        """Update last_used_at, at most once per throttle interval per key, to reduce DB load"""
        throttle_interval = config.API_KEY_LAST_USED_THROTTLE_SECONDS
        current_time = time.time()

        # Try Redis first
        try:
            redis_client = await redis.get_client()
            throttle_key = f"last_used_throttle:{key_id}"
            # SET NX EX is atomic, so exactly one concurrent caller wins the
            # flag. A GET followed by SETEX lets several callers through.
            acquired = await redis_client.set(
                throttle_key, "1", ex=throttle_interval, nx=True
            )
            if not acquired:
                # Already updated within throttle interval, skip
                return
        except Exception as redis_error:
            # Fallback to in-memory throttling when Redis unavailable
            logger.debug(
                f"Redis unavailable for throttling, using in-memory fallback: {redis_error}"
            )
            cache = self._throttle_cache  # the shared class-level dict

            # Prune stale entries once the cache grows large. Prune in place
            # so every instance keeps sharing the same dict.
            if len(cache) > self._THROTTLE_CACHE_MAX_ENTRIES:
                cutoff_time = current_time - throttle_interval * 2  # extra buffer
                for stale_key in [k for k, v in cache.items() if v <= cutoff_time]:
                    del cache[stale_key]

            if current_time - cache.get(key_id, 0) < throttle_interval:
                # Already updated within throttle interval, skip
                return
            cache[key_id] = current_time

        # Update database
        try:
            client = await self.db.client
            await client.table("api_keys").update(
                {"last_used_at": datetime.now(timezone.utc).isoformat()}
            ).eq("key_id", key_id).execute()
            logger.debug(f"Updated last_used_at for key {key_id}")
        except Exception as e:
            logger.warning(f"Failed to update last_used_at for key {key_id}: {e}")

    async def _update_last_used_async(self, key_id: str):
        """Legacy method - kept for backwards compatibility"""
        await self._update_last_used_throttled(key_id)

    async def _clear_throttle(self, key_id: str):
        """Clear the Redis throttle flag for a specific key (useful for testing)"""
        try:
            redis_client = await redis.get_client()
            throttle_key = f"last_used_throttle:{key_id}"
            await redis_client.delete(throttle_key)
            logger.debug(f"Cleared throttle for key {key_id}")
        except Exception as e:
            logger.warning(f"Failed to clear throttle for key {key_id}: {e}")

    async def delete_api_key(self, account_id: UUID, key_id: UUID) -> bool:
        """
        Delete an API key permanently and evict its cached successful validation

        Args:
            account_id: The account ID that owns the key
            key_id: The ID of the key to delete

        Returns:
            True if successful

        Raises:
            HTTPException: 404 if no key matches (key_id, account_id);
                500 on any other failure
        """
        try:
            client = await self.db.client
            result = (
                await client.table("api_keys")
                .delete()
                .eq("key_id", str(key_id))
                .eq("account_id", str(account_id))
                .execute()
            )
            if not result.data:
                raise HTTPException(status_code=404, detail="API key not found")

            # Otherwise a cached success would keep the deleted key usable for its TTL
            await self._invalidate_validation_cache(result.data)

            logger.debug(
                "API key deleted successfully",
                account_id=str(account_id),
                key_id=str(key_id),
            )
            return True
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error deleting API key: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Failed to delete API key") from e