Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Authentication and security module for Hugging Face Spaces | |
| Uses Hugging Face tokens for authentication | |
| """ | |
| import os | |
| import logging | |
| import secrets | |
| import hashlib | |
| import time | |
| from typing import Optional, Dict, Any | |
| from fastapi import Request, HTTPException, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
# Module-level logger for this auth module.
# NOTE(review): not referenced by any code visible in this file.
logger = logging.getLogger(__name__)

class HuggingFaceAuth:
    """Authentication manager using Hugging Face tokens.

    Keeps an in-memory session store and an in-memory rate-limit log. Both
    live in plain dicts on this instance, so they are per-process, are lost
    on restart, and are not shared between separate instances or workers.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Create the manager.

        Args:
            hf_token: Token for this Space. Falls back to the ``HF_TOKEN``
                environment variable, then to an empty string. It is stored
                but not used by ``validate_token``.
        """
        self.hf_token = hf_token or os.getenv("HF_TOKEN", "")
        # Salted SHA-256 of a session ID -> session metadata dict.
        self.session_tokens: Dict[str, Dict[str, Any]] = {}
        # Random per-instance salt, so session hashes are meaningless to
        # any other instance.
        self.token_salt = secrets.token_hex(32)
        # Rate-limit key -> timestamps of requests accepted in the window.
        self._request_log: Dict[str, list] = {}

    def validate_token(self, token: str) -> bool:
        """Return True if ``token`` looks like a Hugging Face token.

        This is a format-only check: it does not contact the Hugging Face
        API. Any non-empty string starting with ``"hf_"`` is accepted,
        including revoked or fabricated tokens.
        """
        if not token:
            return False
        # TODO: in production, verify the token against the Hugging Face API.
        return token.startswith("hf_")

    def create_session(self, token: str, ttl_seconds: float = 24 * 3600) -> Optional[str]:
        """Create a session for a valid token.

        Args:
            token: Hugging Face token. It is stored in the session data as-is.
            ttl_seconds: Session lifetime in seconds (default 24 hours).

        Returns:
            A new URL-safe session ID, or None if ``token`` fails
            ``validate_token``. Only a salted hash of the ID is stored.
        """
        if not self.validate_token(token):
            return None
        # Drop expired sessions opportunistically so the store cannot grow
        # without bound when clients never come back to validate them.
        self._purge_expired_sessions()
        now = time.time()  # single clock read keeps the timestamps consistent
        session_id = secrets.token_urlsafe(32)
        self.session_tokens[self._hash_session_id(session_id)] = {
            "token": token,
            "created_at": now,
            "last_activity": now,
            "expires_at": now + ttl_seconds,
        }
        return session_id

    def validate_session(self, session_id: str) -> bool:
        """Return True if ``session_id`` refers to a live session.

        Expired sessions are deleted when encountered. On success the
        session's ``last_activity`` timestamp is refreshed; the expiry is not
        extended.
        """
        if not session_id:
            return False
        session_hash = self._hash_session_id(session_id)
        session_data = self.session_tokens.get(session_hash)
        if not session_data:
            return False
        now = time.time()
        if now > session_data["expires_at"]:
            del self.session_tokens[session_hash]
            return False
        session_data["last_activity"] = now
        return True

    def revoke_session(self, session_id: str) -> bool:
        """Delete a session. Return True if it existed, False otherwise."""
        return self.session_tokens.pop(self._hash_session_id(session_id), None) is not None

    def _purge_expired_sessions(self, now: Optional[float] = None) -> int:
        """Remove every expired session and return how many were removed."""
        now = time.time() if now is None else now
        expired = [h for h, data in self.session_tokens.items() if now > data["expires_at"]]
        for session_hash in expired:
            del self.session_tokens[session_hash]
        return len(expired)

    def _hash_session_id(self, session_id: str) -> str:
        """Hash a session ID with this instance's salt for storage."""
        return hashlib.sha256(
            f"{session_id}{self.token_salt}".encode()
        ).hexdigest()

    def get_rate_limit_key(self, identifier: str) -> str:
        """Return the rate-limit bucket key for ``identifier``."""
        return f"rate_limit:{identifier}"

    def check_rate_limit(self, identifier: str, limit: int = 100, window: int = 3600) -> bool:
        """Record a request for ``identifier`` and report whether it is allowed.

        This is a sliding-window limiter: at most ``limit`` requests are
        accepted per ``window`` seconds. Rejected requests are not recorded,
        so a blocked caller regains access as soon as its oldest accepted
        request ages out of the window.

        The state is in-memory and per-instance. A multi-process deployment
        would need a shared store such as Redis.

        Returns:
            True if the request is within the limit (and was recorded),
            False if the limit is already reached.
        """
        key = self.get_rate_limit_key(identifier)
        now = time.time()
        cutoff = now - window
        recent = [ts for ts in self._request_log.get(key, []) if ts > cutoff]
        if len(recent) >= limit:
            self._request_log[key] = recent
            return False
        recent.append(now)
        self._request_log[key] = recent
        return True


# HTTP Bearer authentication scheme object (FastAPI security utility).
# NOTE(review): not referenced elsewhere in this file; get_auth_token() parses
# the Authorization header by hand instead. Presumably exported so routes can
# use it via Depends(security) - confirm against callers before removing.
security = HTTPBearer()

def get_auth_token(request: Request) -> Optional[str]:
    """Extract an authentication token from an incoming request.

    Sources are checked in this order, and the first non-empty value wins:

    1. ``Authorization: Bearer <token>`` header. The scheme name is matched
       case-insensitively (RFC 7235) and surrounding whitespace is stripped.
    2. ``token`` query parameter.
    3. ``hf_token`` cookie.

    A header with an empty or non-Bearer credential does not short-circuit
    the lookup; the query parameter and cookie are still consulted.

    Note: tokens sent in the query string may be recorded in server or proxy
    access logs; prefer the header where the client allows it.

    Returns:
        The token string, or None if no source supplied one.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, _, credentials = auth_header.strip().partition(" ")
        credentials = credentials.strip()
        if scheme.lower() == "bearer" and credentials:
            return credentials

    token = request.query_params.get("token")
    if token:
        return token

    token = request.cookies.get("hf_token")
    if token:
        return token

    return None


def require_auth(hf_token: str = ""):
    """Return an async callable that authenticates an incoming request.

    The returned callable takes a ``Request`` rather than wrapping a function.
    NOTE(review): it is presumably meant to be used as a FastAPI dependency,
    e.g. ``Depends(require_auth())``; confirm against call sites.

    Each call builds its own HuggingFaceAuth instance, separate from the
    shared one returned by get_auth(). Its rate-limit state is therefore
    private to this dependency.

    The callable raises HTTPException with:
        * 401 "Authentication required" if no token is found,
        * 401 "Invalid authentication token" if the token fails validation,
        * 429 "Rate limit exceeded" if the token's rate limit is hit.
    Otherwise it returns the request unchanged.
    """
    checker = HuggingFaceAuth(hf_token)

    def _reject(code: int, message: str) -> HTTPException:
        # Build (but do not raise) the HTTP error for a failed check.
        return HTTPException(status_code=code, detail=message)

    async def dependency(request: Request):
        credential = get_auth_token(request)
        if not credential:
            raise _reject(status.HTTP_401_UNAUTHORIZED, "Authentication required")
        if not checker.validate_token(credential):
            raise _reject(status.HTTP_401_UNAUTHORIZED, "Invalid authentication token")
        # The raw token doubles as the rate-limit identifier.
        if not checker.check_rate_limit(credential):
            raise _reject(status.HTTP_429_TOO_MANY_REQUESTS, "Rate limit exceeded")
        return request

    return dependency


# Process-wide HuggingFaceAuth, created lazily by get_auth() or replaced by setup_auth().
_auth_instance: Optional[HuggingFaceAuth] = None


def get_auth() -> HuggingFaceAuth:
    """Return the shared HuggingFaceAuth, creating it on first use.

    A lazily created instance takes its token from the ``HF_TOKEN``
    environment variable (or an empty string if unset). Call setup_auth()
    first to supply a token explicitly.
    """
    global _auth_instance
    instance = _auth_instance
    if instance is None:
        instance = _auth_instance = HuggingFaceAuth()
    return instance


def setup_auth(hf_token: str):
    """Replace the shared HuggingFaceAuth with one built from ``hf_token``.

    The previous shared instance, and any sessions it held, is dropped. An
    empty ``hf_token`` falls back to the ``HF_TOKEN`` environment variable,
    as in HuggingFaceAuth's constructor.
    """
    global _auth_instance
    _auth_instance = HuggingFaceAuth(hf_token=hf_token)


# WebSocket authentication
async def authenticate_websocket(websocket, token: Optional[str] = None):
    """Authenticate a WebSocket connection.

    The token is ``token`` if truthy, otherwise the ``token`` query parameter
    of the connection. It is checked with the shared auth instance from
    get_auth(), which is only touched when a token is present.

    Returns:
        True if the token is valid. Otherwise the socket is closed with
        code 1008 (policy violation) and False is returned.

    Unlike get_auth_token(), Authorization headers and cookies are not
    consulted here, and no rate limit is applied.
    """
    candidate = token or websocket.query_params.get("token")
    # Short-circuit: get_auth() is not called when no token was supplied.
    if candidate and get_auth().validate_token(candidate):
        return True
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    return False


if __name__ == "__main__":
    # Smoke-test the authentication module using dummy tokens only.
    auth = HuggingFaceAuth()

    # Token validation: only the "hf_" prefix is checked.
    test_tokens = [
        "hf_dummytoken123",  # Valid format (starts with hf_)
        "invalid_token",  # Invalid format
        "",  # Empty
        "test_token",  # Invalid format
    ]
    for token in test_tokens:
        valid = auth.validate_token(token)
        print(f"Token '{token[:10]}...': {valid}")

    # Session lifecycle: create -> validate -> revoke -> validate again.
    dummy_token = "hf_dummytoken456"
    session_id = auth.create_session(dummy_token)
    if session_id:
        # Slice only after the None check: create_session returns None for
        # tokens that fail validation.
        print(f"\nCreated session: {session_id[:20]}...")
        valid = auth.validate_session(session_id)
        print(f"Session validation: {valid}")
        revoked = auth.revoke_session(session_id)
        print(f"Session revoked: {revoked}")
        valid_after_revoke = auth.validate_session(session_id)
        print(f"Session validation after revoke: {valid_after_revoke}")
    else:
        print("\nSession creation failed")

    print("\n✅ Auth module test completed with dummy tokens only")