Sameric934's picture
Upload auth.py with huggingface_hub
1f968f2 verified
#!/usr/bin/env python3
"""
Authentication and security module for Hugging Face Spaces
Uses Hugging Face tokens for authentication
"""
import os
import logging
import secrets
import hashlib
import time
from typing import Optional, Dict, Any
from fastapi import Request, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Module-level logger. NOTE(review): no code visible in this module currently
# logs through it; kept so importers/extensions can share the module's logger.
logger = logging.getLogger(__name__)
class HuggingFaceAuth:
    """Authentication manager using Hugging Face tokens.

    Sessions are kept in an in-memory dict keyed by a salted SHA-256 hash of
    the session ID, so the raw IDs handed to clients are never stored. The
    salt is generated per instance, so sessions are only valid for the
    instance that created them and are lost when it is discarded.
    """

    # Prefix that Hugging Face access tokens carry.
    TOKEN_PREFIX = "hf_"
    # Default lifetime of a session, in seconds (24 hours).
    DEFAULT_SESSION_TTL = 24 * 3600

    def __init__(self, hf_token: Optional[str] = None,
                 session_ttl: float = DEFAULT_SESSION_TTL):
        """Create an auth manager.

        Args:
            hf_token: Token associated with this manager; falls back to the
                ``HF_TOKEN`` environment variable, then ``""``. Note that
                ``validate_token`` does not currently compare against it.
            session_ttl: Lifetime of newly created sessions, in seconds.
        """
        self.hf_token = hf_token or os.getenv("HF_TOKEN", "")
        self.session_ttl = session_ttl
        # salted hash of session ID -> {"token", "created_at", "last_activity", "expires_at"}
        self.session_tokens: Dict[str, Dict[str, Any]] = {}
        self.token_salt = secrets.token_hex(32)

    def validate_token(self, token: str) -> bool:
        """Return True if ``token`` looks like a Hugging Face token.

        This is a format check only: the token must start with ``hf_`` and
        have at least one character after it. It does not contact the
        Hugging Face API and does not compare against ``self.hf_token``.
        """
        if not token:
            return False
        # NOTE(review): placeholder validation; a production deployment should
        # verify the token against the Hugging Face API before trusting it.
        prefix = self.TOKEN_PREFIX
        return token.startswith(prefix) and len(token) > len(prefix)

    def create_session(self, token: str) -> Optional[str]:
        """Create a session for a valid ``token`` and return its session ID.

        Returns None when ``token`` fails ``validate_token``. Expired sessions
        are purged first so the session table does not grow without bound.
        """
        if not self.validate_token(token):
            return None

        now = time.time()  # one timestamp so created/activity/expiry agree
        self._purge_expired_sessions(now)

        session_id = secrets.token_urlsafe(32)
        # Only the salted hash is stored; the raw ID is returned to the caller.
        self.session_tokens[self._hash_session_id(session_id)] = {
            "token": token,
            "created_at": now,
            "last_activity": now,
            "expires_at": now + self.session_ttl,
        }
        return session_id

    def validate_session(self, session_id: str) -> bool:
        """Return True if ``session_id`` names a live session.

        An expired session is removed and reported as invalid; a live one has
        its ``last_activity`` timestamp refreshed.
        """
        if not session_id:
            return False

        session_hash = self._hash_session_id(session_id)
        session_data = self.session_tokens.get(session_hash)
        if not session_data:
            return False

        now = time.time()
        if now > session_data["expires_at"]:
            self.session_tokens.pop(session_hash, None)
            return False

        session_data["last_activity"] = now
        return True

    def revoke_session(self, session_id: str) -> bool:
        """Remove a session; return True if it existed, False otherwise."""
        if not session_id:
            return False
        removed = self.session_tokens.pop(self._hash_session_id(session_id), None)
        return removed is not None

    def _hash_session_id(self, session_id: str) -> str:
        """Return the salted SHA-256 hex digest used as the storage key."""
        return hashlib.sha256(
            f"{session_id}{self.token_salt}".encode()
        ).hexdigest()

    def _purge_expired_sessions(self, now: Optional[float] = None) -> int:
        """Delete every session whose expiry time has passed; return the count."""
        if now is None:
            now = time.time()
        expired = [
            session_hash
            for session_hash, data in self.session_tokens.items()
            if now > data["expires_at"]
        ]
        for session_hash in expired:
            del self.session_tokens[session_hash]
        return len(expired)

    def get_rate_limit_key(self, identifier: str) -> str:
        """Return the storage key used to track rate limits for ``identifier``."""
        return f"rate_limit:{identifier}"

    def check_rate_limit(self, identifier: str, limit: int = 100, window: int = 3600) -> bool:
        """Return True if ``identifier`` may proceed, False if it is rate-limited.

        NOTE(review): placeholder. This always returns True; ``limit`` (requests)
        and ``window`` (seconds) are accepted for API compatibility but are
        not enforced. A real implementation would track per-key request
        timestamps under ``get_rate_limit_key(identifier)`` in shared storage
        (e.g. Redis) or a bounded in-memory structure.
        """
        return True
# HTTP Bearer authentication scheme object.
# NOTE(review): nothing visible in this module references `security`
# (get_auth_token parses the Authorization header by hand). It is presumably
# used by importers, e.g. via Depends(security); confirm before removing.
security = HTTPBearer()
def get_auth_token(request: Request) -> Optional[str]:
    """Extract an authentication token from ``request``.

    Sources are tried in this order, and the first non-empty token wins:

    1. ``Authorization: Bearer <token>`` header. The scheme is matched
       case-insensitively and surrounding whitespace is ignored.
    2. The ``token`` query parameter.
    3. The ``hf_token`` cookie.

    Returns None when no source yields a token.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header:
        # Auth-scheme names are case-insensitive (RFC 7235 section 2.1), so
        # accept "bearer"/"BEARER" too. An empty credential ("Bearer ") falls
        # through to the other sources instead of returning "".
        scheme, _, credentials = auth_header.strip().partition(" ")
        credentials = credentials.strip()
        if scheme.lower() == "bearer" and credentials:
            return credentials

    # NOTE(review): tokens passed in query strings tend to end up in access
    # logs and browser history; consider restricting this fallback.
    token = request.query_params.get("token")
    if token:
        return token

    token = request.cookies.get("hf_token")
    if token:
        return token

    return None
def require_auth(hf_token: str = ""):
    """Build a per-request authentication check.

    Despite the historical "decorator" wording, this does not wrap a
    function. It returns an async callable that takes a ``Request`` and
    either returns that same request or raises ``HTTPException``. This is
    the shape of a FastAPI dependency (presumably used as
    ``Depends(require_auth(...))``; confirm at call sites).

    Each call builds its own ``HuggingFaceAuth``, which is independent of the
    module-level instance managed by ``get_auth``/``setup_auth``.

    The returned callable raises:
        HTTPException(401): no token was found, or the token fails validation.
        HTTPException(429): the rate-limit check rejects the token.
    """
    checker = HuggingFaceAuth(hf_token)

    def _deny(code: int, message: str):
        raise HTTPException(status_code=code, detail=message)

    async def dependency(request: Request):
        presented = get_auth_token(request)
        if not presented:
            _deny(status.HTTP_401_UNAUTHORIZED, "Authentication required")
        if not checker.validate_token(presented):
            _deny(status.HTTP_401_UNAUTHORIZED, "Invalid authentication token")
        # Throttle per token, only after the token has been validated.
        if not checker.check_rate_limit(presented):
            _deny(status.HTTP_429_TOO_MANY_REQUESTS, "Rate limit exceeded")
        return request

    return dependency
# Process-wide auth manager, created lazily by get_auth() or replaced by setup_auth().
_auth_instance: Optional[HuggingFaceAuth] = None


def get_auth() -> HuggingFaceAuth:
    """Return the shared auth manager, creating a default one on first use.

    The default instance takes its token from the ``HF_TOKEN`` environment
    variable (via ``HuggingFaceAuth.__init__``).
    """
    global _auth_instance
    if _auth_instance is not None:
        return _auth_instance
    _auth_instance = HuggingFaceAuth()
    return _auth_instance
def setup_auth(hf_token: str):
    """Replace the shared auth manager with a new one bound to ``hf_token``.

    Any sessions held by the previous shared instance are discarded along
    with it. NOTE(review): ``HuggingFaceAuth.validate_token`` only checks the
    token format and does not compare against the token configured here.
    """
    global _auth_instance
    replacement = HuggingFaceAuth(hf_token)
    _auth_instance = replacement
# WebSocket authentication
async def authenticate_websocket(websocket, token: Optional[str] = None):
    """Authenticate a WebSocket connection.

    Uses ``token`` if given and non-empty; otherwise it falls back to the
    ``token`` query parameter of the connection. If no token is found, or the
    shared auth manager rejects it, the socket is closed with policy-violation
    code 1008 and False is returned. Otherwise returns True, leaving the
    socket open.

    NOTE(review): query-string tokens can leak into server/proxy logs.
    """
    candidate = token or websocket.query_params.get("token")
    if candidate and get_auth().validate_token(candidate):
        return True
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    return False
if __name__ == "__main__":
    # Smoke-test the authentication module using dummy tokens only
    # (no real credentials are used or printed).
    auth = HuggingFaceAuth()

    # Format validation: only strings that look like "hf_..." should pass.
    test_tokens = [
        "hf_dummytoken123",  # Valid format (starts with hf_)
        "invalid_token",  # Invalid format
        "",  # Empty
        "test_token",  # Invalid format
    ]
    for token in test_tokens:
        valid = auth.validate_token(token)
        print(f"Token '{token[:10]}...': {valid}")

    # Session lifecycle: create -> validate -> revoke -> validate again.
    dummy_token = "hf_dummytoken456"
    session_id = auth.create_session(dummy_token)
    # Guard before slicing: create_session returns None for rejected tokens.
    if session_id:
        print(f"\nCreated session: {session_id[:20]}...")
        valid = auth.validate_session(session_id)
        print(f"Session validation: {valid}")
        revoked = auth.revoke_session(session_id)
        print(f"Session revoked: {revoked}")
        valid_after_revoke = auth.validate_session(session_id)
        print(f"Session validation after revoke: {valid_after_revoke}")
    else:
        print("\nSession creation failed for the dummy token")

    print("\n✅ Auth module test completed with dummy tokens only")