# ShreyasGosavi's picture
# Upload 37 files
# 53bec59 verified
"""
Authentication and Authorization
JWT tokens, API keys, password hashing
"""
import secrets
from datetime import datetime, timedelta
from typing import Optional, Union
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, APIKeyHeader
from sqlalchemy.orm import Session
from src.core.config import settings
from src.core.exceptions import AuthenticationError, AuthorizationError
from src.db.models import User, APIKey, get_db
# Password hashing context shared by verify_password / get_password_hash.
# bcrypt is the only configured scheme.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Security schemes
# Bearer-token scheme used by the JWT dependency (default auto_error behavior).
bearer_scheme = HTTPBearer()
# X-API-Key header scheme; auto_error=False so a missing header is treated as
# "no API key supplied" by the dependencies below rather than an error here.
api_key_header = APIKeyHeader(
    name="X-API-Key",
    auto_error=False,
)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The stored hash produced by ``get_password_hash``.

    Returns:
        True if the password matches the hash, otherwise False.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the module's bcrypt ``CryptContext``.

    Args:
        password: The plaintext password to hash.

    Returns:
        The encoded hash string, suitable for storing and later passing to
        ``verify_password``.
    """
    digest = pwd_context.hash(password)
    return digest
def create_access_token(
    data: dict,
    expires_delta: Optional[timedelta] = None
) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``sub``). The dict is copied, not mutated;
            any ``exp`` key in it is overwritten.
        expires_delta: Token lifetime. ``None`` means use
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit zero or
            negative delta is honored rather than silently replaced.

    Returns:
        The encoded JWT, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    # `is None` rather than truthiness: timedelta(0) is falsy and would
    # otherwise fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode.update({"exp": datetime.utcnow() + expires_delta})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_refresh_token(
    data: dict,
    expires_delta: Optional[timedelta] = None
) -> str:
    """Create a signed JWT refresh token.

    The token carries a ``"type": "refresh"`` claim that marks it as a
    refresh token.

    Args:
        data: Claims to embed (e.g. ``sub``). The dict is copied, not mutated;
            any ``exp`` or ``type`` key in it is overwritten.
        expires_delta: Token lifetime. ``None`` means use
            ``settings.REFRESH_TOKEN_EXPIRE_DAYS``. An explicit zero or
            negative delta is honored rather than silently replaced.

    Returns:
        The encoded JWT, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    # `is None` rather than truthiness: timedelta(0) is falsy and would
    # otherwise fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    to_encode = data.copy()
    # Applied after the caller's claims so "type" cannot be overridden.
    to_encode.update({"exp": datetime.utcnow() + expires_delta, "type": "refresh"})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_token(token: str) -> dict:
    """Decode a JWT and verify it with ``settings.SECRET_KEY``.

    Only ``settings.ALGORITHM`` is accepted as the signing algorithm.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        AuthenticationError: if ``jose`` raises any ``JWTError`` while
            decoding (bad signature, malformed token, expired, etc.).
    """
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError as exc:
        # Chain the original error so the underlying jose reason stays
        # visible in tracebacks and logs.
        raise AuthenticationError("Invalid or expired token") from exc
def generate_api_key() -> str:
    """Return a new random, URL-safe API key.

    The key encodes 32 bytes from the ``secrets`` CSPRNG as unpadded
    base64url text.
    """
    entropy_bytes = 32
    return secrets.token_urlsafe(nbytes=entropy_bytes)
# Dependency: Get current user from JWT token
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Security(bearer_scheme),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the active user identified by a Bearer access token.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``.
        db: Database session.

    Returns:
        The matching, active ``User``.

    Raises:
        AuthenticationError: if the token is invalid or expired, is a refresh
            token, has no ``sub`` claim, or refers to a missing or inactive
            user.
    """
    # decode_token already converts jose's JWTError into AuthenticationError,
    # so no separate JWTError handler is needed here.
    payload = decode_token(credentials.credentials)
    # Refresh tokens carry "type": "refresh" and must not be usable as
    # access tokens.
    if payload.get("type") == "refresh":
        raise AuthenticationError("Invalid token type")
    user_id = payload.get("sub")
    if user_id is None:
        raise AuthenticationError("Invalid token payload")
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise AuthenticationError("User not found")
    if not user.is_active:
        raise AuthenticationError("User account is inactive")
    return user
# Dependency: Get current user from API key
async def get_current_user_from_api_key(
    api_key: Optional[str] = Security(api_key_header),
    db: Session = Depends(get_db)
) -> Optional[User]:
    """Resolve the user owning the key supplied in the ``X-API-Key`` header.

    Args:
        api_key: Raw key from the header; empty or ``None`` when absent.
        db: Database session.

    Returns:
        ``None`` when no key was supplied; otherwise the active owning ``User``.

    Raises:
        AuthenticationError: if the key is unknown or inactive, has expired,
            or its owner is missing or inactive.
    """
    if not api_key:
        return None

    # Look up an active key with this exact value.
    api_key_obj = db.query(APIKey).filter(
        APIKey.key == api_key,
        APIKey.is_active == True  # noqa: E712 - builds a SQL expression
    ).first()
    if not api_key_obj:
        raise AuthenticationError("Invalid API key")

    # expires_at is compared as a naive UTC datetime, matching how
    # create_api_key_for_user stores it.
    now = datetime.utcnow()
    if api_key_obj.expires_at and api_key_obj.expires_at < now:
        raise AuthenticationError("API key has expired")

    user = db.query(User).filter(User.id == api_key_obj.user_id).first()
    if not user or not user.is_active:
        raise AuthenticationError("User not found or inactive")

    # Record usage only after authentication has fully succeeded, so rejected
    # requests (e.g. an inactive owner) do not bump last_used_at.
    api_key_obj.last_used_at = now
    db.commit()
    return user
# Bearer scheme that yields None (instead of raising) when the Authorization
# header is missing, so get_current_user_flexible can fall back to an API key.
# auto_error must be configured on the scheme itself: fastapi.Security() has no
# auto_error parameter.
_optional_bearer_scheme = HTTPBearer(auto_error=False)


# Dependency: Get current user (try JWT first, then API key)
async def get_current_user_flexible(
    bearer: Optional[HTTPAuthorizationCredentials] = Security(_optional_bearer_scheme),
    api_key: Optional[str] = Security(api_key_header),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the current user from a Bearer JWT, falling back to an API key.

    The Bearer token is tried first. An invalid or expired token, a refresh
    token, or a token whose user is missing or inactive is ignored so that the
    ``X-API-Key`` header can be tried next.

    Args:
        bearer: Optional Bearer credentials.
        api_key: Optional raw API key from the ``X-API-Key`` header.
        db: Database session.

    Returns:
        The authenticated, active ``User``.

    Raises:
        AuthenticationError: if a supplied API key is rejected, or if neither
            credential authenticates a user.
    """
    if bearer:
        try:
            payload = decode_token(bearer.credentials)
        except AuthenticationError:
            # Bad token: fall through to the API key instead of failing.
            # Other errors (e.g. database failures) are no longer swallowed.
            payload = None
        # Refresh tokens must not authenticate ordinary requests.
        if payload is not None and payload.get("type") != "refresh":
            user_id = payload.get("sub")
            if user_id is not None:
                user = db.query(User).filter(User.id == user_id).first()
                if user and user.is_active:
                    return user

    if api_key:
        user = await get_current_user_from_api_key(api_key, db)
        if user:
            return user

    raise AuthenticationError("Authentication required")
# Dependency: Require superuser
async def get_current_superuser(
    current_user: User = Depends(get_current_user_flexible)
) -> User:
    """Return the authenticated user only if they are a superuser.

    Raises:
        AuthorizationError: if the authenticated user lacks superuser rights.
    """
    if current_user.is_superuser:
        return current_user
    raise AuthorizationError("Superuser privileges required")
# Helper: Authenticate user
def authenticate_user(
    db: Session,
    email: str,
    password: str
) -> Optional[User]:
    """Look up a user by email and check their password.

    Args:
        db: Database session.
        email: Email address to look up (matched exactly).
        password: Plaintext password to verify.

    Returns:
        The ``User`` on success, or ``None`` if no user has this email or the
        password does not match. The ``is_active`` flag is not checked here.
    """
    user = db.query(User).filter(User.email == email).first()
    if not user:
        # Spend comparable hashing time for unknown emails, so response timing
        # does not reveal which addresses are registered.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
# Helper: Create user
def create_user(
    db: Session,
    email: str,
    password: str,
    full_name: Optional[str] = None,
    is_superuser: bool = False
) -> User:
    """Persist a new active user and return the refreshed instance.

    Args:
        db: Database session; the new row is committed.
        email: Email address for the account.
        password: Plaintext password; only its hash is stored.
        full_name: Optional display name.
        is_superuser: Whether to grant superuser privileges.

    Returns:
        The newly created ``User``, refreshed from the database.

    Raises:
        ValueError: if a user with ``email`` already exists.
    """
    if db.query(User).filter(User.email == email).first():
        raise ValueError("User with this email already exists")

    new_user = User(
        email=email,
        hashed_password=get_password_hash(password),
        full_name=full_name,
        is_superuser=is_superuser,
        is_active=True,
    )
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return new_user
# Helper: Create API key
def create_api_key_for_user(
    db: Session,
    user_id: int,
    name: Optional[str] = None,
    expires_days: Optional[int] = None
) -> APIKey:
    """Create, commit and return a new active API key for ``user_id``.

    Args:
        db: Database session; the new row is committed.
        user_id: Owner of the key.
        name: Human-readable label; a falsy value becomes ``"API Key"``.
        expires_days: Lifetime in days from now (naive UTC). A falsy value
            (``None`` or ``0``) creates a key with no expiry.

    Returns:
        The refreshed ``APIKey`` row. Its rate limits come from
        ``settings.RATE_LIMIT_PER_MINUTE`` / ``settings.RATE_LIMIT_PER_HOUR``.
    """
    raw_key = generate_api_key()
    expiry = None
    if expires_days:
        expiry = datetime.utcnow() + timedelta(days=expires_days)

    record = APIKey(
        key=raw_key,
        name=name or "API Key",
        user_id=user_id,
        is_active=True,
        rate_limit_per_minute=settings.RATE_LIMIT_PER_MINUTE,
        rate_limit_per_hour=settings.RATE_LIMIT_PER_HOUR,
        expires_at=expiry,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
if __name__ == "__main__":
    # Smoke-test the helpers in this module.

    # Test password hashing
    password = "test_password_123"
    hashed = get_password_hash(password)
    print(f"Hashed: {hashed}")
    print(f"Verified: {verify_password(password, hashed)}")

    # Test JWT token creation. The "sub" claim must be a string: python-jose
    # rejects non-string subjects when decoding, so decode_token() would
    # raise AuthenticationError for an int subject.
    token = create_access_token({"sub": "1", "email": "test@example.com"})
    print(f"Token: {token}")
    payload = decode_token(token)
    print(f"Decoded: {payload}")

    # Test API key generation
    api_key = generate_api_key()
    print(f"API Key: {api_key}")