rohanshaw's picture
Upload 8 files
cfd8098 verified
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status, Body
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from motor.motor_asyncio import AsyncIOMotorClient
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr, Field

# Import database collections and the PyObjectId type
from .db import user_collection, PyObjectId
# Load environment variables from the .env file. This must run before the
# environment lookups below so that values from .env are visible to them.
load_dotenv()

# --- Configuration ---
# Key used to sign and verify JWTs; None when the variable is not set.
JWT_SECRET = os.environ.get("JWT_SECRET")
JWT_ALGORITHM = "HS256"
# Default token lifetime in minutes; 360 when the variable is not set.
JWT_EXPIRES_MINUTES = int(os.environ.get("JWT_EXPIRES_MINUTES", 360))

# --- Hashing ---
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- OAuth2 ---
# Only declares the bearer-token *scheme* (used by FastAPI's docs); the
# login logic itself lives in the route handler.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
# === UTILITY FUNCTIONS ===
def get_password_hash(password: str) -> str:
    """Return the hash of a plain-text password.

    The hash is produced by the module-level ``pwd_context``.
    """
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a stored hash.

    Returns the result of ``pwd_context.verify`` for the given pair.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is copied, so the
            caller's object is not modified.
        expires_delta: Lifetime of the token. When ``None``, the lifetime
            defaults to ``JWT_EXPIRES_MINUTES`` minutes.

    Returns:
        The encoded token, signed with ``JWT_SECRET`` using ``JWT_ALGORITHM``.
    """
    # Compare against None explicitly: a zero timedelta is falsy, and a
    # truthiness test would silently replace it with the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=JWT_EXPIRES_MINUTES)

    # Timezone-aware "now" in UTC; datetime.utcnow() is deprecated.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
# === PYDANTIC MODELS (Data Validation) ===
class Token(BaseModel):
    """Response body of the login endpoint."""

    # The encoded JWT.
    access_token: str
    # Token scheme reported to the client; the login handler sends "bearer".
    token_type: str
class UserCreate(BaseModel):
    """Request body accepted by the registration endpoint."""

    # Validated as an e-mail address by pydantic's EmailStr.
    email: EmailStr
    # Plain-text password; it is hashed before being stored.
    password: str
class UserInDB(BaseModel):
    """User record in the shape it is stored in the database."""

    # Exposed as ``id`` but stored under the ``_id`` key; a fresh
    # PyObjectId is generated when none is supplied.
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
    email: EmailStr
    password_hash: str
    # Use a default_factory so every instance gets its own timestamp. A plain
    # ``datetime.now()`` default is evaluated once, when the class is defined,
    # and would then be shared by all instances. UTC matches what the
    # registration endpoint stores.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc)
    )

    class Config:
        # V2 replacement for 'allow_population_by_field_name':
        validate_by_name = True
        arbitrary_types_allowed = True
        json_encoders = {PyObjectId: str}
async def get_current_active_user(token: str = Depends(oauth2_scheme)):
    """Resolve a bearer token to the matching user document.

    Intended to be used as a FastAPI dependency to protect endpoints.

    Args:
        token: Bearer token supplied through ``oauth2_scheme``.

    Returns:
        The document returned by ``user_collection.find_one`` for the e-mail
        stored in the token's ``sub`` claim. No projection is applied, so the
        document carries every stored field (documents written by the
        registration endpoint include ``password_hash``); filter it before
        returning it to a client.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or no user with that e-mail exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body limited to the one call that can raise JWTError.
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except JWTError as exc:
        # Chain the cause so the underlying JWT failure is visible in logs.
        raise credentials_exception from exc

    # 'sub' is the e-mail stored at login; reject a missing or empty value
    # before touching the database.
    email = payload.get("sub")
    if not email:
        raise credentials_exception

    user = await user_collection.find_one({"email": email})
    if user is None:
        raise credentials_exception
    return user
# === API ROUTER ===
# Router for the authentication endpoints. All routes are served under
# /api/auth, and the tag groups them together in the /docs page.
AuthRouter = APIRouter(prefix="/api/auth", tags=["Authentication"])
@AuthRouter.post("/register", status_code=status.HTTP_201_CREATED)
async def register_user(user_in: UserCreate = Body(...)):
    """Create a new user account.

    Admin creates the first account.
    (As per spec, this is open, but in future you might lock it down)

    Args:
        user_in: E-mail and plain-text password of the account to create.

    Returns:
        A dict with ``status``, ``message`` and the new ``user_id`` as a
        string.

    Raises:
        HTTPException: 400 when the e-mail is already registered, 500 when
            the insert does not report an inserted id.
    """
    # NOTE(review): this check-then-insert is not atomic. Two concurrent
    # requests for the same e-mail can both pass the check; a unique index on
    # "email" would be needed to rule that out. Confirm one exists.
    existing_user = await user_collection.find_one({"email": user_in.email})
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered."
        )

    new_user = {
        "email": user_in.email,
        # Only the hash is stored, never the plain-text password.
        "password_hash": get_password_hash(user_in.password),
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated.
        "created_at": datetime.now(timezone.utc),
    }

    result = await user_collection.insert_one(new_user)
    if not result.inserted_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user."
        )

    return {
        "status": "success",
        "message": "User created successfully",
        "user_id": str(result.inserted_id)
    }
@AuthRouter.post("/login", response_model=Token)
async def login_for_access_token(
    # FastAPI's OAuth2 form helper
    # It expects form data: 'username' and 'password'
    form_data: OAuth2PasswordRequestForm = Depends()
):
    """Log a user in and return a JWT access token.

    Args:
        form_data: OAuth2 password form; ``username`` carries the e-mail.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 401 when the e-mail is unknown or the password does
            not match. Both cases produce the same response.
    """
    # Note: form_data.username is the 'email'
    user = await user_collection.find_one({"email": form_data.username})

    if user is None:
        # Spend roughly the time a real hash check would take, so response
        # timing does not reveal whether the e-mail is registered.
        pwd_context.dummy_verify()
        authenticated = False
    else:
        authenticated = verify_password(form_data.password, user["password_hash"])

    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # User is valid, create a token; 'sub' (subject) is the user's email
    access_token = create_access_token(data={"sub": user["email"]})
    return {
        "access_token": access_token,
        "token_type": "bearer"
    }