"""Authentication helpers: API-key validation, JWT issuing and JWT-based user lookup."""

import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.base import get_db
from app.db.models import User
from app.models.token import TokenData

logger = logging.getLogger(__name__)

# OAuth2 scheme for token authentication
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")

# API Key security scheme. auto_error=False so a missing header reaches
# get_api_key_user, which returns its own 401 with a specific message.
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


# Use API key from settings
async def get_api_key_user(
    api_key: str = Depends(api_key_header),
) -> bool:
    """
    Validate the API key from the request header.

    Args:
        api_key: The API key from the request header (None/empty if absent)

    Returns:
        bool: True if the API key is valid

    Raises:
        HTTPException: 401 if the API key is missing or does not match
            settings.API_KEY
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    expected = settings.API_KEY
    # Constant-time comparison avoids leaking key prefixes through response
    # timing. Compare bytes so non-ASCII input cannot raise TypeError.
    # An unset/empty configured key always rejects (fail closed), which
    # matches the previous `!=` behaviour for those values.
    if not expected or not secrets.compare_digest(
        api_key.encode("utf-8"), str(expected).encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    return True


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a new JWT access token.

    Args:
        data: The claims to encode in the token (copied; not mutated)
        expires_delta: Optional lifetime of the token. When None, the
            configured settings.ACCESS_TOKEN_EXPIRE_MINUTES is used.

    Returns:
        str: The encoded JWT token, signed with settings.SECRET_KEY using
        settings.ALGORITHM
    """
    to_encode = data.copy()

    if expires_delta is None:
        # Use configured expiration time from settings
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware UTC timestamp; jose serialises it as a Unix timestamp.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})

    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> User:
    """
    Get the current authenticated user from the token.

    Args:
        token: The JWT bearer token
        db: Database session

    Returns:
        User: The authenticated, active user

    Raises:
        HTTPException: 401 if the token cannot be decoded, is expired, lacks
            "sub"/"exp" claims, or refers to an unknown user; 403 if the user
            is inactive
    """
    credentials_exception = _unauthorized("Could not validate credentials")

    try:
        # Decode and verify the JWT (signature and standard claims).
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except ExpiredSignatureError:
        # Must be caught before JWTError (it is a subclass). Its message,
        # "Signature has expired", would otherwise match the "signature"
        # heuristic below and be misreported as a bad signature.
        logger.info("JWT rejected: token expired")
        raise _unauthorized("Token expired")
    except JWTError as e:
        logger.warning("JWT validation error: %s", e)
        message = str(e).lower()
        if "signature" in message or "invalid" in message:
            raise _unauthorized("Invalid token signature")
        raise credentials_exception

    # Extract user_id from token
    user_id: Optional[str] = payload.get("sub")
    if user_id is None:
        raise credentials_exception
    token_data = TokenData(user_id=user_id)

    # Tokens must carry an expiration claim.
    exp = payload.get("exp")
    if exp is None:
        logger.warning("Token has no expiration: %s", token_data.user_id)
        raise credentials_exception

    # Explicit expiry check, done in UTC on both sides. The previous code
    # compared local time (fromtimestamp) with naive UTC (utcnow), which was
    # wrong on non-UTC hosts. Done before the DB query so expired tokens
    # never cost a lookup.
    try:
        expiry_time = datetime.fromtimestamp(exp, tz=timezone.utc)
    except (TypeError, ValueError, OverflowError, OSError) as e:
        logger.warning("Unparseable token expiration %r: %s", exp, e)
        raise credentials_exception
    current_time = datetime.now(timezone.utc)
    logger.debug(
        "Token expiration check: current=%s, expiry=%s, seconds_remaining=%s",
        current_time,
        expiry_time,
        (expiry_time - current_time).total_seconds(),
    )
    if expiry_time < current_time:
        logger.info(
            "Token expired for user: %s, expired at %s",
            token_data.user_id,
            expiry_time,
        )
        raise _unauthorized("Token expired")

    # Get the user from the database
    user = db.query(User).filter(User.id == token_data.user_id).first()
    if user is None:
        logger.warning("User not found in database: %s", token_data.user_id)
        raise credentials_exception

    # Check if user is active
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
        )

    return user


async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Get the current active user.

    get_current_user already rejects inactive users; this check is kept so
    the dependency stays safe if that upstream check ever changes.

    Args:
        current_user: The current authenticated user

    Returns:
        User: The current active user

    Raises:
        HTTPException: 403 if the user is inactive
    """
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
        )
    return current_user