Spaces:
Sleeping
Sleeping
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from dotenv import load_dotenv
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from motor.motor_asyncio import AsyncIOMotorClient
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr, Field

# Import database collections and the PyObjectId type
from .db import PyObjectId, user_collection
# Populate os.environ from a local .env file (when one exists) before any of
# the settings below are read.
load_dotenv()

# --- Configuration ---
# Signing key for issued JWTs. NOTE(review): this is None when the variable
# is unset, and it is passed as-is to jwt.encode / jwt.decode — confirm the
# deployment always provides it.
JWT_SECRET = os.environ.get("JWT_SECRET")
JWT_ALGORITHM = "HS256"
# Token lifetime in minutes; falls back to 360 (6 hours) when not configured.
JWT_EXPIRES_MINUTES = int(os.environ.get("JWT_EXPIRES_MINUTES", 360))

# --- Hashing ---
# bcrypt is the only scheme; deprecated="auto" marks non-default schemes as
# deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# --- OAuth2 ---
# Declares the bearer-token scheme (used by FastAPI for dependency injection
# and the /docs "Authorize" button); it does not implement login itself.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
| # === UTILITY FUNCTIONS === | |
def get_password_hash(password: str) -> str:
    """Return the hash of a plain-text *password*, produced by ``pwd_context``."""
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a stored hash.

    Delegates to ``pwd_context.verify`` and returns its result: True when
    *plain_password* matches *hashed_password*, False otherwise.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token (login passes ``{"sub": email}``).
            The mapping is copied, so the caller's dict is not modified.
        expires_delta: Token lifetime. When omitted (``None``) the configured
            ``JWT_EXPIRES_MINUTES`` is used. An explicit ``timedelta(0)`` is
            honoured instead of being silently replaced by the default.

    Returns:
        The encoded token, signed with ``JWT_SECRET`` using ``JWT_ALGORITHM``.
    """
    # `is None` rather than truthiness: timedelta(0) is falsy but is a
    # deliberate value, not "use the default".
    if expires_delta is None:
        expires_delta = timedelta(minutes=JWT_EXPIRES_MINUTES)

    # Aware UTC "now" replaces the deprecated datetime.utcnow(); jose converts
    # datetime "exp" claims to a UTC epoch timestamp either way.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
# === PYDANTIC MODELS (Data Validation) ===
# NOTE: these models use `#` comments rather than class docstrings because
# Pydantic copies a class docstring into the JSON schema "description".


class Token(BaseModel):
    # Bearer-token payload: the encoded JWT plus its scheme name.
    access_token: str
    token_type: str


class UserCreate(BaseModel):
    # Registration request body: an email-validated address and a password.
    email: EmailStr
    password: str


class UserInDB(BaseModel):
    # Shape of a stored user document; `id` is populated from Mongo's "_id".
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
    email: EmailStr
    password_hash: str
    # default_factory runs per instance. The former `= datetime.now()` was
    # evaluated once at import time, stamping every model with the same
    # (local-time) value. UTC matches the "created_at" written by
    # register_user.
    created_at: datetime = Field(default_factory=datetime.utcnow)

    class Config:
        # V2 replacement for 'allow_population_by_field_name':
        validate_by_name = True
        # PyObjectId is a custom type, so Pydantic must be told to accept it.
        arbitrary_types_allowed = True
        # Serialise ObjectIds as strings. NOTE(review): json_encoders is
        # deprecated in Pydantic v2 — consider a field serializer.
        json_encoders = {PyObjectId: str}
async def get_current_active_user(token: str = Depends(oauth2_scheme)):
    """Resolve the request's bearer token to the stored user document.

    Intended as a FastAPI dependency for protected endpoints. The token is
    decoded with ``JWT_SECRET`` / ``JWT_ALGORITHM``. Its ``sub`` claim holds
    the user's email (login stores ``{"sub": user["email"]}``), which is then
    looked up in ``user_collection``. No "active"/disabled flag is checked,
    despite the name.

    Returns:
        The raw document returned by ``user_collection.find_one``, not just
        the email. NOTE(review): it includes ``password_hash`` — do not return
        it directly from an endpoint.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            fails to decode (bad signature, malformed, or expired), carries no
            ``sub`` claim, or names a user that is not in the database.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body to the one call that raises JWTError.
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except JWTError as err:
        # Chain the cause so server-side tracebacks show why decoding failed.
        raise credentials_exception from err

    email = payload.get("sub")
    if email is None:
        raise credentials_exception

    user = await user_collection.find_one({"email": email})
    if user is None:
        raise credentials_exception
    return user
# === API ROUTER ===
# Every route on this router is served under /api/auth and listed under the
# "Authentication" tag in the generated /docs.
AuthRouter = APIRouter(tags=["Authentication"], prefix="/api/auth")
async def register_user(user_in: UserCreate = Body(...)):
    """Create a user account from an email/password pair.

    Per the spec this is open (the admin uses it to create the first
    account); it may be locked down in future.

    NOTE(review): no route decorator precedes this function in the reviewed
    source, so it is not attached to ``AuthRouter`` here — confirm it is
    registered elsewhere.

    Returns:
        ``{"status", "message", "user_id"}`` with the new id as a string.

    Raises:
        HTTPException: 400 if the email is already registered; 500 if the
            insert reports no id.
    """
    duplicate = await user_collection.find_one({"email": user_in.email})
    if duplicate:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered.",
        )

    # NOTE(review): the lookup above and the insert below are not atomic;
    # concurrent requests for the same email can both pass the check unless
    # a unique index on "email" exists — confirm one is configured.
    document = {
        "email": user_in.email,
        "password_hash": get_password_hash(user_in.password),
        "created_at": datetime.utcnow(),
    }
    outcome = await user_collection.insert_one(document)

    new_id = outcome.inserted_id
    if not new_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user.",
        )

    return {
        "status": "success",
        "message": "User created successfully",
        "user_id": str(new_id),
    }
async def login_for_access_token(
    # OAuth2PasswordRequestForm parses the form fields 'username' and
    # 'password'; here 'username' carries the user's email.
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """Authenticate email/password form credentials and issue a JWT.

    NOTE(review): no route decorator precedes this function in the reviewed
    source, although ``oauth2_scheme`` points at ``/api/auth/login`` —
    confirm it is registered on ``AuthRouter`` elsewhere.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``; the token's
        ``sub`` claim is the user's email.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the email
            is unknown or the password does not match. Both cases produce the
            same response so callers cannot tell them apart.
    """
    user = await user_collection.find_one({"email": form_data.username})

    if user:
        password_ok = verify_password(form_data.password, user["password_hash"])
    else:
        # Spend about as long as a real bcrypt check so response latency does
        # not reveal whether the email is registered.
        pwd_context.dummy_verify()
        password_ok = False

    if not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # 'sub' (subject) is the user's email; get_current_active_user reads it back.
    access_token = create_access_token(data={"sub": user["email"]})
    return {
        "access_token": access_token,
        "token_type": "bearer",
    }