Spaces:
Sleeping
Sleeping
File size: 5,778 Bytes
cfd8098 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | import os
from datetime import datetime, timedelta, timezone

from dotenv import load_dotenv
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from motor.motor_asyncio import AsyncIOMotorClient
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr, Field

# Import database collections and the PyObjectId type
from .db import PyObjectId, user_collection
# Load environment variables from a .env file (if present) into os.environ
load_dotenv()

# --- Configuration ---
JWT_SECRET = os.getenv("JWT_SECRET")
if not JWT_SECRET:
    # Fail fast at startup. A missing secret would only surface later as a
    # server error on the first login. An empty secret is worse: HMAC with an
    # empty key lets anyone forge valid tokens.
    raise RuntimeError(
        "JWT_SECRET environment variable is not set (or is empty); "
        "refusing to start without a JWT signing key."
    )
JWT_ALGORITHM = "HS256"
# Token lifetime in minutes; defaults to 360 (6 hours) when not configured.
JWT_EXPIRES_MINUTES = int(os.getenv("JWT_EXPIRES_MINUTES", 360))

# --- Hashing ---
# bcrypt for new hashes; deprecated="auto" marks any other configured scheme as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- OAuth2 ---
# This only declares the bearer-token *scheme* (used by FastAPI's docs and
# dependency injection); token validation lives in get_current_active_user.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
# === UTILITY FUNCTIONS ===
def get_password_hash(password: str) -> str:
    """Return a bcrypt hash of ``password`` for storage instead of the plain text."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when ``plain_password`` matches the stored ``hashed_password``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``{"sub": email}``). The dict is not mutated.
        expires_delta: Token lifetime. ``None`` (the default) means
            ``JWT_EXPIRES_MINUTES`` minutes. Any other value is used as given,
            including ``timedelta(0)`` (an already-expired token).

    Returns:
        The encoded JWT string with an ``exp`` claim added.
    """
    # `is not None` rather than truthiness: timedelta(0) is falsy and must not
    # silently fall back to the multi-hour default lifetime.
    if expires_delta is not None:
        lifetime = expires_delta
    else:
        lifetime = timedelta(minutes=JWT_EXPIRES_MINUTES)
    # Aware UTC datetime; jose converts it to a Unix timestamp for "exp".
    expire = datetime.now(timezone.utc) + lifetime
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
# === PYDANTIC MODELS (Data Validation) ===
class Token(BaseModel):
    """Response body returned by the login endpoint."""

    # The encoded JWT.
    access_token: str
    # Token scheme; the login route returns "bearer".
    token_type: str
class UserCreate(BaseModel):
    """Request body for registering a new account."""

    # Validated as an email address by pydantic's EmailStr.
    email: EmailStr
    # Plain-text password; it is hashed before storage by the register route.
    password: str
class UserInDB(BaseModel):
    """Model of a stored user document."""

    # Populated from MongoDB's "_id" key; a new PyObjectId is generated when absent.
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
    email: EmailStr
    # bcrypt hash of the password, never the plain text.
    password_hash: str
    # default_factory so each instance gets its own creation time. A plain
    # `= datetime.now()` default is evaluated once at class definition and
    # shared by every instance. Stored as timezone-aware UTC.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    class Config:
        # V2 replacement for 'allow_population_by_field_name':
        validate_by_name = True
        arbitrary_types_allowed = True
        json_encoders = {PyObjectId: str}
async def get_current_active_user(token: str = Depends(oauth2_scheme)):
    """
    Resolve a bearer token to the stored user document.

    Intended as a FastAPI dependency for protected endpoints. The token is
    decoded with JWT_SECRET / JWT_ALGORITHM, its "sub" claim is read as the
    user's email, and the matching document is fetched from user_collection.

    Returns:
        The user document found by email (the full document, including
        password_hash).

    Raises:
        HTTPException: 401 if the token is invalid, has no "sub" claim,
            or no user with that email exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except JWTError:
        raise unauthorized

    # The login route stores the user's email in the 'sub' claim.
    email = claims.get("sub")
    if email is None:
        raise unauthorized

    user = await user_collection.find_one({"email": email})
    if user is None:
        raise unauthorized
    return user
# === API ROUTER ===
# Every auth endpoint is mounted under /api/auth and grouped under one tag in /docs.
AuthRouter = APIRouter(prefix="/api/auth", tags=["Authentication"])
@AuthRouter.post("/register", status_code=status.HTTP_201_CREATED)
async def register_user(user_in: UserCreate = Body(...)):
    """
    Admin creates the first account.
    (As per spec, this is open, but in future you might lock it down)

    Stores the email, a bcrypt hash of the password (never the password
    itself) and a UTC creation timestamp.

    Returns:
        A success payload containing the new document id as a string.

    Raises:
        HTTPException: 400 if the email is already registered; 500 if the
            insert reports no inserted id.
    """
    # NOTE(review): check-then-insert is not atomic. Two concurrent requests for
    # the same email can both pass this check. A unique index on "email" plus
    # handling the duplicate-key error on insert would close that gap.
    existing_user = await user_collection.find_one({"email": user_in.email})
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered."
        )

    hashed_password = get_password_hash(user_in.password)

    new_user = {
        "email": user_in.email,
        "password_hash": hashed_password,
        # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
        "created_at": datetime.now(timezone.utc),
    }

    result = await user_collection.insert_one(new_user)
    if not result.inserted_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user."
        )

    return {
        "status": "success",
        "message": "User created successfully",
        "user_id": str(result.inserted_id)
    }
@AuthRouter.post("/login", response_model=Token)
async def login_for_access_token(
    # FastAPI's OAuth2 form helper: expects form fields 'username' and 'password'
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """
    Log in a user and return a JWT bearer token.

    The form's ``username`` field carries the user's email. On success, the
    token's ``sub`` claim is set to that email.

    Raises:
        HTTPException: 401 "Incorrect email or password" when the email is
            unknown, the stored document has no password hash, or the
            password does not match.
    """
    user = await user_collection.find_one({"email": form_data.username})
    stored_hash = user.get("password_hash") if user else None

    if stored_hash is None:
        # Unknown email (or a document without a hash): still spend roughly the
        # cost of a real bcrypt check, so response time does not reveal which
        # emails are registered. This also avoids a KeyError -> HTTP 500.
        pwd_context.dummy_verify()
        password_ok = False
    else:
        password_ok = verify_password(form_data.password, stored_hash)

    if not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # 'sub' (subject) is the user's email; get_current_active_user reads it back.
    access_token = create_access_token(data={"sub": user["email"]})
    return {
        "access_token": access_token,
        "token_type": "bearer"
    }