Spaces:
Sleeping
Sleeping
File size: 4,593 Bytes
45742a7 3963750 45742a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
"""
Authentication module for API key management.
Handles password hashing, API key generation (hashed), rate limiting, and validation.
"""
import hashlib
import secrets
from datetime import date, datetime, timezone

import bcrypt
from fastapi import Depends, Header, HTTPException
from pydantic import BaseModel, EmailStr
# --- Configuration ---
# Daily request quota given to every newly created user document
# (stored per user as "rate_limit" and enforced by check_rate_limit).
DEFAULT_RATE_LIMIT: int = 100
# --- Pydantic Models ---
class UserSignup(BaseModel):
    """Request body for signup."""
    # Validated as an email address by pydantic's EmailStr type.
    email: EmailStr
    # Plaintext password as sent by the client; presumably passed to
    # create_user_document, which stores only its bcrypt hash — confirm in the route.
    # NOTE(review): no length constraints are declared; bcrypt only considers the
    # first 72 bytes of input — consider a max_length if that matters here.
    password: str
class UserLogin(BaseModel):
    """Request body for login."""
    # Validated as an email address by pydantic's EmailStr type.
    email: EmailStr
    # Plaintext password; presumably checked with verify_password against the
    # stored "password_hash" — confirm in the login route.
    password: str
class UserResponse(BaseModel):
    """Response after signup/login."""
    # The account's email address.
    email: str
    # Presumably the raw key from generate_api_key; only its hash is persisted,
    # so this response is the one chance to show it — verify against the routes.
    api_key: str
    # Human-readable status text for the client.
    message: str
class UsageResponse(BaseModel):
    """Response for usage endpoint."""
    # The account's email address.
    email: str
    # Requests counted so far for the current day ("requests_today" in the user doc).
    requests_today: int
    # Per-user daily quota ("rate_limit" in the user doc).
    rate_limit: int
    # Presumably rate_limit - requests_today, computed by the caller — confirm.
    remaining: int
    # Lifetime request count ("total_requests" in the user doc).
    total_requests: int
# --- Helper Functions ---
def hash_password(password: str) -> str:
    """
    Return a bcrypt hash of *password* using a freshly generated salt.

    The hash is decoded to ``str`` so it can be stored directly in a document.
    NOTE(review): bcrypt only considers the first 72 bytes of the input.
    """
    secret_bytes = password.encode('utf-8')
    hashed_bytes = bcrypt.hashpw(secret_bytes, bcrypt.gensalt())
    return hashed_bytes.decode('utf-8')
def verify_password(password: str, hashed: str) -> bool:
    """
    Return True if *password* matches the stored bcrypt hash *hashed*.

    NOTE(review): bcrypt.checkpw raises rather than returning False when
    *hashed* is not a well-formed bcrypt hash.
    """
    candidate = password.encode('utf-8')
    stored = hashed.encode('utf-8')
    return bcrypt.checkpw(candidate, stored)
def generate_api_key() -> tuple[str, str, str]:
    """
    Generate a unique API key in OpenAI style: ``sk-live-<48 hex chars>``.

    Returns:
        (raw_key, key_hash, key_prefix)
        - raw_key: the full 56-character key; show it to the user ONCE.
        - key_hash: SHA-256 hex digest of raw_key, to store in the DB.
        - key_prefix: the first 16 characters of raw_key followed by "..."
          (e.g. "sk-live-abc12345..."), safe to display.
    """
    random_part = secrets.token_hex(24)  # 24 random bytes -> 48 hex characters
    raw_key = f"sk-live-{random_part}"
    # Must produce the same digest as hash_api_key, which is used for lookups.
    key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
    key_prefix = raw_key[:16] + "..."  # "sk-live-" + first 8 random hex chars
    return raw_key, key_hash, key_prefix
def hash_api_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of *api_key* (UTF-8 encoded), as stored in the DB."""
    hasher = hashlib.sha256()
    hasher.update(api_key.encode())
    return hasher.hexdigest()
def create_user_document(email: str, password: str) -> tuple[dict, str]:
    """
    Create a new user document for MongoDB.

    A fresh API key is generated; only its SHA-256 hash and display prefix are
    stored, and the password is stored only as a bcrypt hash.

    Args:
        email: the user's email address, stored as given.
        password: the plaintext password.

    Returns:
        (user_doc, raw_api_key): the document to insert, and the raw key,
        which is not persisted and must be shown to the user once.
    """
    raw_key, key_hash, key_prefix = generate_api_key()
    user_doc = {
        "email": email,
        "password_hash": hash_password(password),
        "api_key_hash": key_hash,
        "api_key_prefix": key_prefix,
        # Daily usage counters. last_request_date is the local date as an ISO
        # string, the same form check_rate_limit compares against.
        "requests_today": 0,
        "last_request_date": str(date.today()),
        "total_requests": 0,
        "rate_limit": DEFAULT_RATE_LIMIT,
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated since
        # Python 3.12 and returns a naive datetime.
        "created_at": datetime.now(timezone.utc),
        "last_login": None,
    }
    return user_doc, raw_key
async def check_rate_limit(user: dict, db) -> bool:
    """
    Check and update rate limit for user.

    The limit check and the increment happen in one conditional MongoDB
    update, so concurrent requests for the same user cannot push
    ``requests_today`` past ``rate_limit``. (A separate read, compare and
    increment lets every in-flight request pass the check.)

    Args:
        user: user document loaded from ``db.users``; must contain ``_id``.
            Its ``requests_today`` is set to 0 in place when a new day starts.
        db: database handle exposing an async ``users`` collection
            (presumably Motor — confirm).

    Returns:
        True if the request is within the limit.

    Raises:
        HTTPException: 429 when the user has already used ``rate_limit``
            requests today.
    """
    today = str(date.today())
    rate_limit = user.get("rate_limit", DEFAULT_RATE_LIMIT)

    # Reset counter if new day. Filtering on last_request_date makes this a
    # no-op when a concurrent request already rolled the counter over, so it
    # cannot wipe out increments that were recorded today.
    if user.get("last_request_date") != today:
        await db.users.update_one(
            {"_id": user["_id"], "last_request_date": {"$ne": today}},
            {"$set": {"requests_today": 0, "last_request_date": today}},
        )
        user["requests_today"] = 0

    # Increment only while under the limit. A missing counter counts as 0,
    # matching the original user.get("requests_today", 0); $inc creates it.
    result = await db.users.update_one(
        {
            "_id": user["_id"],
            "$or": [
                {"requests_today": {"$lt": rate_limit}},
                {"requests_today": {"$exists": False}},
            ],
        },
        {
            "$inc": {"requests_today": 1, "total_requests": 1},
            "$set": {"last_request_date": today},
        },
    )
    if result.matched_count == 0:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Limit: {rate_limit} requests/day. Try again tomorrow."
        )
    return True
# --- API Key Validation Dependency ---
async def validate_api_key(x_api_key: str = Header(..., description="Your API key")):
    """
    FastAPI dependency that authenticates a request by its API key header.

    The incoming key is SHA-256 hashed and matched against ``api_key_hash``
    in the ``users`` collection; the raw key itself is never queried. A
    matching user is then charged one request via ``check_rate_limit``.

    Returns:
        The matching user document.

    Raises:
        HTTPException: 500 if no database is available, 401 for an unknown
            key, or whatever ``check_rate_limit`` raises (429 on exhaustion).
    """
    # Import lazily; supports both package-relative and top-level layouts.
    try:
        from .database import get_database
    except ImportError:
        from database import get_database

    database = get_database()
    if database is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    key_digest = hash_api_key(x_api_key)
    matched_user = await database.users.find_one({"api_key_hash": key_digest})
    if not matched_user:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key. Please login to get your API key."
        )

    # Counts this request against the user's daily quota (may raise 429).
    await check_rate_limit(matched_user, database)
    return matched_user
|