# Spaces: Running on CPU Upgrade
# File size: 4,229 Bytes
# 61d29fc
"""
Authentication utilities - JWT tokens, password hashing, OAuth helpers
"""
import os
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from dotenv import load_dotenv
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from api.database import get_db
from api.models import User
# Load environment variables
# Runs at import time, ahead of the os.getenv() lookup in the security
# configuration that follows.
load_dotenv()
# Security configuration
# JWT_SECRET_KEY is the HMAC signing key for access tokens. When it is unset
# OR set to an empty string, a random key is generated for this process instead
# (an empty signing key must never be used). Tokens signed with the random
# fallback cannot be validated by another process or after a restart, so set
# JWT_SECRET_KEY explicitly in any real deployment.
SECRET_KEY = os.getenv("JWT_SECRET_KEY") or secrets.token_urlsafe(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days
# Password hashing
# bcrypt is the only scheme configured; hash_password() and verify_password()
# both go through this context.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# HTTP Bearer token scheme
# auto_error=False lets a request without credentials reach get_current_user(),
# which handles the empty case itself by returning None (optional auth);
# require_auth() is the dependency that turns that into a 401.
security = HTTPBearer(auto_error=False)
def hash_password(password: str) -> str:
    """Return the hash of ``password`` produced by the module-level ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against ``hashed_password`` via ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Claims to encode (usually {"sub": user_id}). The dict is copied,
            so the caller's object is not modified. Any "exp" key it contains
            is overwritten.
            NOTE(review): python-jose's jwt.decode rejects a non-string "sub"
            claim, so the id should be passed as a string - confirm call sites.
        expires_delta: Lifetime of the token. When None, the default of
            ACCESS_TOKEN_EXPIRE_MINUTES minutes is used. A zero timedelta is
            honoured as given instead of falling back to the default.

    Returns:
        JWT token string signed with SECRET_KEY using ALGORITHM.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so a truthiness
    # check would silently replace a zero lifetime with the 7-day default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    to_encode["exp"] = datetime.utcnow() + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_access_token(token: str) -> Dict[str, Any]:
    """
    Verify a JWT with SECRET_KEY / ALGORITHM and return its claims.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            decoding raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    else:
        return claims
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> Optional[User]:
    """
    Resolve the current user from the request's bearer token (optional auth).

    Usage:
        @app.get("/maybe-protected")
        def route(user: Optional[User] = Depends(get_current_user)):
            if user is None:
                return {"message": "Hello anonymous"}
            return {"message": f"Hello {user.email}"}

    Use ``require_auth`` instead when the route must reject anonymous callers.

    Args:
        credentials: Bearer credentials from the ``security`` scheme; falsy
            when the request carried none.
        db: Database session.

    Returns:
        The matching User, or None when no credentials were supplied.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no "sub"
            claim, or names a user that does not exist.
    """
    if not credentials:
        return None

    payload = decode_access_token(credentials.credentials)
    user_id = payload.get("sub")
    if user_id is None:
        # Send the Bearer challenge on every 401, matching
        # decode_access_token() and require_auth().
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Update last login
    # NOTE(review): this writes and commits on every authenticated request,
    # not only at login - confirm that is intended.
    user.last_login = datetime.utcnow()
    db.commit()
    return user
def require_auth(user: User = Depends(get_current_user)) -> User:
    """
    Dependency that insists on an authenticated user.

    Usage:
        @app.get("/protected")
        def protected_route(user: User = Depends(require_auth)):
            return {"message": f"Hello {user.email}"}

    Returns:
        The user resolved by ``get_current_user``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``get_current_user`` produced None.
    """
    if user is not None:
        return user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
def generate_state_token() -> str:
    """Return a fresh URL-safe random token (32 random bytes) for the OAuth ``state`` parameter."""
    random_byte_count = 32
    state = secrets.token_urlsafe(random_byte_count)
    return state
|