"""Authentication helpers and routes.

Provides password hashing, JWT creation/validation, the Pydantic models used
by the auth endpoints, and the ``/api/auth`` router (register + login).
"""

import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from dotenv import load_dotenv
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from motor.motor_asyncio import AsyncIOMotorClient  # kept: may be used elsewhere
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr, Field

# Import database collections and the PyObjectId type
from .db import user_collection, PyObjectId

# Load environment variables from .env file
load_dotenv()

# --- Configuration ---
JWT_SECRET = os.getenv("JWT_SECRET")
JWT_ALGORITHM = "HS256"
# os.getenv returns strings, so the fallback is given as a string as well.
JWT_EXPIRES_MINUTES = int(os.getenv("JWT_EXPIRES_MINUTES", "360"))

# --- Hashing ---
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- OAuth2 ---
# This just defines the *scheme* for FastAPI's docs, not the implementation
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")


# === UTILITY FUNCTIONS ===

def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def get_password_hash(password: str) -> str:
    """Hash a plain-text password with the module's ``pwd_context``."""
    return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def create_access_token(
    data: dict, expires_delta: Optional[timedelta] = None
) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is copied, not mutated.
        expires_delta: Token lifetime. When ``None``, the lifetime defaults
            to ``JWT_EXPIRES_MINUTES`` minutes. An explicit zero-length
            delta is honoured rather than replaced by the default.

    Returns:
        The encoded JWT string.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is not configured.
    """
    if not JWT_SECRET:
        # Fail with an actionable message instead of an opaque error from
        # the JWT library when the signing key is missing or empty.
        raise RuntimeError(
            "JWT_SECRET is not configured; cannot sign access tokens."
        )

    if expires_delta is None:
        expires_delta = timedelta(minutes=JWT_EXPIRES_MINUTES)

    to_encode = data.copy()
    to_encode.update({"exp": _utc_now() + expires_delta})
    return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)


# === PYDANTIC MODELS (Data Validation) ===

class Token(BaseModel):
    """Response body returned by the login endpoint."""

    access_token: str
    token_type: str


class UserCreate(BaseModel):
    """Request body for registering a new user."""

    email: EmailStr
    password: str


class UserInDB(BaseModel):
    """Shape of a user document as stored in the user collection."""

    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
    email: EmailStr
    password_hash: str
    # default_factory so each instance gets its own timestamp; a plain
    # default would be evaluated once, when the class is defined.
    created_at: datetime = Field(default_factory=_utc_now)

    class Config:
        # V2 replacement for 'allow_population_by_field_name':
        validate_by_name = True
        arbitrary_types_allowed = True
        json_encoders = {PyObjectId: str}


async def get_current_active_user(token: str = Depends(oauth2_scheme)):
    """Validate the bearer JWT and return the matching user document.

    Intended to be used as a FastAPI dependency to protect endpoints.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The document returned by ``user_collection.find_one`` for the email
        stored in the token's ``sub`` claim.
        NOTE(review): this is the raw document, so it includes
        ``password_hash``; callers should not return it to clients as-is.

    Raises:
        HTTPException: 401 if the token cannot be decoded, has no ``sub``
            claim, or no user with that email exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        # Decode the token
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except JWTError:
        raise credentials_exception

    email = payload.get("sub")  # 'sub' is the email we stored
    if email is None:
        raise credentials_exception

    # Find the user in the database
    user = await user_collection.find_one({"email": email})
    if user is None:
        raise credentials_exception

    return user


# === API ROUTER ===

AuthRouter = APIRouter(
    prefix="/api/auth",
    tags=["Authentication"],  # Groups endpoints in the /docs
)


@AuthRouter.post("/register", status_code=status.HTTP_201_CREATED)
async def register_user(user_in: UserCreate = Body(...)):
    """Create a new user account.

    Admin creates the first account. (As per spec, this is open, but in
    future you might lock it down.)

    Args:
        user_in: Email and plain-text password for the new account.

    Returns:
        A dict with ``status``, ``message`` and the new ``user_id`` (string).

    Raises:
        HTTPException: 400 if the email is already registered; 500 if the
            insert did not report an inserted id.
    """
    # Check if user already exists.
    # NOTE(review): check-then-insert is not atomic; presumably a unique
    # index on "email" backs this up - confirm in the DB setup.
    existing_user = await user_collection.find_one({"email": user_in.email})
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered.",
        )

    # Only the hash is stored, never the plain-text password.
    new_user = {
        "email": user_in.email,
        "password_hash": get_password_hash(user_in.password),
        "created_at": _utc_now(),
    }

    # Insert new user into database
    result = await user_collection.insert_one(new_user)
    if not result.inserted_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user.",
        )

    return {
        "status": "success",
        "message": "User created successfully",
        "user_id": str(result.inserted_id),
    }


@AuthRouter.post("/login", response_model=Token)
async def login_for_access_token(
    # FastAPI's OAuth2 form helper
    # It expects form data: 'username' and 'password'
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """Log in a user and return a JWT access token.

    Args:
        form_data: OAuth2 password form; ``username`` carries the email.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 401 if the email is unknown or the password is wrong.
    """
    # Note: form_data.username is the 'email'
    user = await user_collection.find_one({"email": form_data.username})

    # Same error for "no such user" and "wrong password" so the response
    # does not reveal which emails are registered.
    if not user or not verify_password(form_data.password, user["password_hash"]):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # User is valid, create a token
    access_token = create_access_token(
        data={"sub": user["email"]}  # 'sub' (subject) is the user's email
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
    }