Spaces:
Paused
Paused
| from fastapi import HTTPException, Header, Depends | |
| from fastapi.security import APIKeyHeader | |
| from typing import Optional | |
| from config import API_KEY, HUGGINGFACE_API_KEY, HUGGINGFACE # Import API_KEY, HUGGINGFACE_API_KEY, HUGGINGFACE | |
| import os | |
| import json | |
| import base64 | |
# Function to validate API key (moved from config.py)
def validate_api_key(api_key_to_validate: str) -> bool:
    """Return True if *api_key_to_validate* matches the configured API_KEY.

    If no API key is configured (API_KEY is falsy), every key is rejected,
    so Bearer authentication fails closed rather than being disabled.

    The comparison uses ``hmac.compare_digest`` so the time taken does not
    depend on how many leading characters of the secret match (prevents
    timing attacks). Both sides are UTF-8 encoded first because
    ``compare_digest`` raises ``TypeError`` for non-ASCII ``str`` input.
    """
    import hmac

    if not API_KEY:  # API_KEY is imported from config
        # Fail closed: with no configured key nothing can authenticate.
        return False
    if not isinstance(api_key_to_validate, str):
        return False
    # NOTE(review): assumes config provides API_KEY as a str — confirm.
    return hmac.compare_digest(
        api_key_to_validate.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
# API Key security scheme: declares an API key carried in the "Authorization"
# request header. auto_error=False means FastAPI yields None for a missing
# header instead of rejecting the request itself.
# NOTE(review): get_api_key in this module reads the header through Header(...)
# and does not reference this object; it may be used elsewhere — confirm
# before removing.
api_key_header = APIKeyHeader(
    auto_error=False,
    name="Authorization",
)
def _decode_x_ip_token_payload(x_ip_token: str) -> dict:
    """Decode and return the JSON payload (second '.'-separated segment) of a JWT.

    Only the payload segment is base64url-decoded and parsed; the signature
    segment is never inspected.

    Raises:
        HTTPException(400): fewer than two segments, a payload that is not
            valid base64url / UTF-8 / JSON, or a payload that is not a JSON object.
        HTTPException(500): any other unexpected error while decoding.
    """
    parts = x_ip_token.split('.')
    if len(parts) < 2:
        message = "Invalid JWT format: Not enough parts to extract payload."
        print(f"ValueError processing x-ip-token: {message}")
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JWT format in x-ip-token: {message}",
        )

    payload_encoded = parts[1]
    # JWTs strip base64 padding, but base64.urlsafe_b64decode requires it.
    payload_encoded += '=' * (-len(payload_encoded) % 4)
    try:
        decoded_payload_bytes = base64.urlsafe_b64decode(payload_encoded)
        payload = json.loads(decoded_payload_bytes.decode('utf-8'))
    except ValueError as e:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses, so this single handler covers every decoding
        # failure. The exception text is echoed to the client in `detail`.
        print(f"Error decoding/parsing x-ip-token payload: {e}")
        raise HTTPException(
            status_code=400,
            detail=f"Malformed x-ip-token payload: {str(e)}",
        ) from e
    except Exception as e:  # Any other unexpected error during token processing
        print(f"Unexpected error processing x-ip-token: {e}")
        raise HTTPException(
            status_code=500,
            detail="Internal error processing x-ip-token.",
        ) from e

    if not isinstance(payload, dict):
        # A JSON list/string/number would otherwise crash on payload.get().
        print(f"Error decoding/parsing x-ip-token payload: expected object, got {type(payload).__name__}")
        raise HTTPException(
            status_code=400,
            detail="Malformed x-ip-token payload: expected a JSON object.",
        )
    return payload


def _authenticate_via_x_ip_token(x_ip_token: Optional[str]):
    """Authenticate with the Hugging Face x-ip-token header.

    Returns HUGGINGFACE_API_KEY (from config) when the token payload's
    "error" field is null or absent.

    NOTE(review): the JWT signature is NOT verified, so any client able to
    set this header can forge a payload with "error": null. Confirm that the
    hosting proxy always sets/overwrites x-ip-token before relying on this.

    Raises:
        HTTPException(401): the header is missing.
        HTTPException(403): the payload reports an error.
        HTTPException(400/500): see _decode_x_ip_token_payload.
    """
    if x_ip_token is None:
        raise HTTPException(
            status_code=401,  # Unauthorised - because x-ip-token is missing
            detail="Missing x-ip-token header. This header is required for Hugging Face authentication."
        )

    payload = _decode_x_ip_token_payload(x_ip_token)
    # .get() returns None both for JSON null and for a missing "error" key,
    # so both are treated as success.
    error_in_token = payload.get("error")
    if error_in_token is None:
        print("HuggingFace authentication successful via x-ip-token (error field was null).")
        # NOTE(review): returned as-is; it is not checked here that
        # HUGGINGFACE_API_KEY is actually configured.
        return HUGGINGFACE_API_KEY
    if error_in_token == "InvalidAccessToken":
        raise HTTPException(
            status_code=403,
            detail="Access denied: x-ip-token indicates 'InvalidAccessToken'."
        )
    # Any other non-null value in the 'error' field
    raise HTTPException(
        status_code=403,
        detail=f"Access denied: x-ip-token indicates an unhandled error: '{error_in_token}'."
    )


def _authenticate_via_bearer(authorization: Optional[str]) -> str:
    """Authenticate with an 'Authorization: Bearer <key>' header.

    Returns the extracted key when validate_api_key accepts it.

    Raises:
        HTTPException(401): header missing, not a Bearer header, or invalid key.
    """
    if authorization is None:
        detail_message = "Missing API key. Please include 'Authorization: Bearer YOUR_API_KEY' header."
        # Hint when a HUGGINGFACE env var exists (any value) but HF mode is off.
        if os.getenv("HUGGINGFACE") is not None:
            detail_message += " (Note: HUGGINGFACE mode with x-ip-token is not currently active)."
        raise HTTPException(
            status_code=401,
            detail=detail_message
        )

    prefix = "Bearer "
    if not authorization.startswith(prefix):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key format. Use 'Authorization: Bearer YOUR_API_KEY'"
        )

    # Strip only the leading scheme; str.replace would also delete any
    # "Bearer " that happens to occur inside the key itself.
    api_key = authorization[len(prefix):]
    if not validate_api_key(api_key):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return api_key


# Dependency for API key validation
async def get_api_key(
    authorization: Optional[str] = Header(None),
    x_ip_token: Optional[str] = Header(None, alias="x-ip-token")
):
    """FastAPI dependency that authenticates the incoming request.

    When HUGGINGFACE (from config) is truthy, the x-ip-token header is
    required and HUGGINGFACE_API_KEY is returned on success. Otherwise the
    request must carry 'Authorization: Bearer <key>' accepted by
    validate_api_key, and that key is returned.

    Raises:
        HTTPException: 400/401/403/500 describing why authentication failed.
    """
    if HUGGINGFACE:
        return _authenticate_via_x_ip_token(x_ip_token)
    return _authenticate_via_bearer(authorization)