Spaces:
Paused
Paused
Commit
·
5b8c4f9
1
Parent(s):
c9e3eb0
huggingface auth fix and fixed model list
Browse files- app/auth.py +90 -26
- app/config.py +4 -0
- app/routes/models_api.py +7 -2
app/auth.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
| 1 |
from fastapi import HTTPException, Header, Depends
|
| 2 |
from fastapi.security import APIKeyHeader
|
| 3 |
from typing import Optional
|
| 4 |
-
from config import API_KEY # Import API_KEY
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Function to validate API key (moved from config.py)
|
| 7 |
def validate_api_key(api_key_to_validate: str) -> bool:
|
|
@@ -18,28 +21,89 @@ def validate_api_key(api_key_to_validate: str) -> bool:
|
|
| 18 |
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
| 19 |
|
| 20 |
# Dependency for API key validation
|
| 21 |
-
async def get_api_key(
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from fastapi import HTTPException, Header, Depends
|
| 2 |
from fastapi.security import APIKeyHeader
|
| 3 |
from typing import Optional
|
| 4 |
+
from config import API_KEY, HUGGINGFACE_API_KEY, HUGGINGFACE # Import API_KEY, HUGGINGFACE_API_KEY, HUGGINGFACE
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import base64
|
| 8 |
|
| 9 |
# Function to validate API key (moved from config.py)
|
| 10 |
def validate_api_key(api_key_to_validate: str) -> bool:
|
|
|
|
| 21 |
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
| 22 |
|
| 23 |
# Dependency for API key validation
|
| 24 |
+
def _decode_jwt_payload(token: str) -> dict:
    """Decode the payload (second dot-separated segment) of a JWT-style token.

    The signature segment is NOT verified; the payload is only base64url- and
    JSON-decoded.

    Args:
        token: Raw token string, expected in ``header.payload[.signature]`` form.

    Returns:
        The decoded payload as a dict.

    Raises:
        ValueError: If the token has fewer than two parts, the payload is not
            valid base64url / UTF-8 / JSON, or the decoded JSON is not an
            object. (``binascii.Error``, ``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are all ``ValueError`` subclasses, so a
            single exception type covers every decoding failure.)
    """
    parts = token.split('.')
    if len(parts) < 2:
        raise ValueError("Invalid JWT format: Not enough parts to extract payload.")
    payload_encoded = parts[1]
    # JWTs strip base64 padding, but base64.urlsafe_b64decode requires it.
    payload_encoded += '=' * (-len(payload_encoded) % 4)
    decoded_payload_bytes = base64.urlsafe_b64decode(payload_encoded)
    payload = json.loads(decoded_payload_bytes.decode('utf-8'))
    # Valid JSON may still be a list/str/number, which has no ``.get``.
    if not isinstance(payload, dict):
        raise ValueError("Invalid JWT format: Payload is not a JSON object.")
    return payload


async def get_api_key(
    authorization: Optional[str] = Header(None),
    x_ip_token: Optional[str] = Header(None, alias="x-ip-token")
) -> str:
    """FastAPI dependency that authenticates a request and returns the API key to use.

    Two modes, selected by the ``HUGGINGFACE`` config flag:

    * ``HUGGINGFACE`` truthy: the ``x-ip-token`` header must carry a JWT-style
      token whose payload has a null/absent ``error`` field. On success the
      configured ``HUGGINGFACE_API_KEY`` is returned.
    * otherwise: the ``Authorization`` header must be ``Bearer <key>`` and
      ``<key>`` must pass ``validate_api_key``. On success ``<key>`` is returned.

    Args:
        authorization: Value of the ``Authorization`` request header, if any.
        x_ip_token: Value of the ``x-ip-token`` request header, if any.

    Returns:
        The API key the request is authorised with.

    Raises:
        HTTPException: 401 for a missing/invalid header or key, 400 for a
            malformed ``x-ip-token``, 403 when the token payload reports an
            error, 500 when ``HUGGINGFACE_API_KEY`` is not configured or token
            processing fails unexpectedly.
    """
    if HUGGINGFACE:  # Hugging Face auth mode, from config
        if x_ip_token is None:
            raise HTTPException(
                status_code=401,  # Unauthorised - because x-ip-token is missing
                detail="Missing x-ip-token header. This header is required for Hugging Face authentication."
            )

        # NOTE(review): the token signature is not verified; the payload is
        # trusted as-is. Confirm this header can only be set by a trusted proxy.
        try:
            payload = _decode_jwt_payload(x_ip_token)
        except ValueError as ve:
            # Log server-side; the reason is also echoed to the client.
            print(f"ValueError processing x-ip-token: {ve}")
            raise HTTPException(
                status_code=400,
                detail=f"Invalid JWT format in x-ip-token: {str(ve)}"
            ) from ve
        except Exception as e:  # Any other unexpected error during token processing
            print(f"Unexpected error processing x-ip-token: {e}")
            raise HTTPException(
                status_code=500,
                detail="Internal error processing x-ip-token."
            ) from e

        error_in_token = payload.get("error")

        if error_in_token == "InvalidAccessToken":
            raise HTTPException(
                status_code=403,
                detail="Access denied: x-ip-token indicates 'InvalidAccessToken'."
            )

        if error_in_token is not None:
            # Any other non-null, non-"InvalidAccessToken" value in 'error'
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: x-ip-token indicates an unhandled error: '{error_in_token}'."
            )

        # 'error' is null (JSON null is Python None) or absent: token accepted.
        # The service still needs a configured key to hand back.
        if not HUGGINGFACE_API_KEY:
            print("Security configuration error: HUGGINGFACE_API_KEY is not configured, but HUGGINGFACE mode is active and x-ip-token is valid.")
            raise HTTPException(
                status_code=500,
                detail="Service security configuration incomplete: HuggingFace API Key not set."
            )
        print("HuggingFace authentication successful via x-ip-token (error field was null).")
        return HUGGINGFACE_API_KEY

    # Bearer token authentication when HUGGINGFACE mode is not active
    if authorization is None:
        detail_message = "Missing API key. Please include 'Authorization: Bearer YOUR_API_KEY' header."
        # Hint when the HUGGINGFACE env var exists but did not enable the mode
        if os.getenv("HUGGINGFACE") is not None:  # Check for existence, not value
            detail_message += " (Note: HUGGINGFACE mode with x-ip-token is not currently active)."
        raise HTTPException(
            status_code=401,
            detail=detail_message
        )

    bearer_prefix = "Bearer "
    if not authorization.startswith(bearer_prefix):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key format. Use 'Authorization: Bearer YOUR_API_KEY'"
        )

    # Strip only the leading prefix; str.replace would also remove any later
    # "Bearer " occurrence inside the key itself.
    api_key = authorization[len(bearer_prefix):]

    if not validate_api_key(api_key):  # Call local validate_api_key
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )

    return api_key
|
app/config.py
CHANGED
|
@@ -6,6 +6,10 @@ DEFAULT_PASSWORD = "123456"
|
|
| 6 |
# Get password from environment variable or use default
|
| 7 |
API_KEY = os.environ.get("API_KEY", DEFAULT_PASSWORD)
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# Directory for service account credential files
|
| 10 |
CREDENTIALS_DIR = os.environ.get("CREDENTIALS_DIR", "/app/credentials")
|
| 11 |
|
|
|
|
| 6 |
# Get password from environment variable or use default
|
| 7 |
API_KEY = os.environ.get("API_KEY", DEFAULT_PASSWORD)
|
| 8 |
|
| 9 |
+
# HuggingFace Authentication Settings
|
| 10 |
+
HUGGINGFACE = os.environ.get("HUGGINGFACE", "false").lower() == "true"
|
| 11 |
+
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY", "") # Defaults to an empty string; the auth logic rejects requests with a 500 error when HUGGINGFACE mode is active and this key is unset
|
| 12 |
+
|
| 13 |
# Directory for service account credential files
|
| 14 |
CREDENTIALS_DIR = os.environ.get("CREDENTIALS_DIR", "/app/credentials")
|
| 15 |
|
app/routes/models_api.py
CHANGED
|
@@ -25,6 +25,7 @@ async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_k
|
|
| 25 |
raw_express_models = await get_vertex_express_models()
|
| 26 |
|
| 27 |
candidate_model_ids = set()
|
|
|
|
| 28 |
|
| 29 |
if has_express_key:
|
| 30 |
candidate_model_ids.update(raw_express_models)
|
|
@@ -57,8 +58,12 @@ async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_k
|
|
| 57 |
for original_model_id in sorted(list(all_model_ids)):
|
| 58 |
current_display_prefix = ""
|
| 59 |
# Only add PAY_PREFIX if the model is not already an EXPRESS model (which has its own prefix)
|
| 60 |
-
if
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
current_display_prefix = PAY_PREFIX
|
| 63 |
|
| 64 |
base_display_id = f"{current_display_prefix}{original_model_id}"
|
|
|
|
| 25 |
raw_express_models = await get_vertex_express_models()
|
| 26 |
|
| 27 |
candidate_model_ids = set()
|
| 28 |
+
raw_vertex_models_set = set(raw_vertex_models) # For checking origin during prefixing
|
| 29 |
|
| 30 |
if has_express_key:
|
| 31 |
candidate_model_ids.update(raw_express_models)
|
|
|
|
| 58 |
for original_model_id in sorted(list(all_model_ids)):
|
| 59 |
current_display_prefix = ""
|
| 60 |
# Only add PAY_PREFIX if the model is not already an EXPRESS model (which has its own prefix)
|
| 61 |
+
# Apply PAY_PREFIX if SA creds are present, it's a model from raw_vertex_models,
|
| 62 |
+
# it's not experimental, and not already an EXPRESS model.
|
| 63 |
+
if has_sa_creds and \
|
| 64 |
+
original_model_id in raw_vertex_models_set and \
|
| 65 |
+
EXPERIMENTAL_MARKER not in original_model_id and \
|
| 66 |
+
not original_model_id.startswith("[EXPRESS]"):
|
| 67 |
current_display_prefix = PAY_PREFIX
|
| 68 |
|
| 69 |
base_display_id = f"{current_display_prefix}{original_model_id}"
|