Spaces:
Paused
Paused
from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Body, Depends, Form, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from ..core.security import create_access_token, get_password_hash, verify_password
from ..db.database import get_db
from ..db.models import User
from ..db.schemas import LoginData, UserCreate, UserInDB
# Relative URL that OAuth2 clients (e.g. the interactive docs) are told to
# POST their username/password to in order to obtain a bearer token.
_TOKEN_URL = "token"

# Router that collects this module's authentication endpoints.
# NOTE(review): none of the handlers visible in this module carry an
# `@router.<method>(...)` decorator — confirm they are registered somewhere.
router = APIRouter()

# Dependency that extracts a bearer token from the Authorization header,
# advertising `_TOKEN_URL` as the password-flow token endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_TOKEN_URL)
async def login_form(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Exchange OAuth2 password-form credentials for a bearer token.

    The OAuth2 password form names its identifier field ``username``; this
    handler treats that value as the user's email address.

    Args:
        form_data: Form-encoded ``username`` / ``password`` fields.
        db: Async database session provided by ``get_db``.

    Returns:
        The token payload produced by ``authenticate_user``.

    Raises:
        HTTPException: Propagated from ``authenticate_user`` on bad credentials.
    """
    return await authenticate_user(
        db=db,
        email=form_data.username,
        password=form_data.password,
    )
async def login_json(
    login_data: LoginData,
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Exchange a JSON credentials body for a bearer token.

    Args:
        login_data: Request body carrying ``email`` and ``password``.
        db: Async database session provided by ``get_db``.

    Returns:
        The token payload produced by ``authenticate_user``.

    Raises:
        HTTPException: Propagated from ``authenticate_user`` on bad credentials.
    """
    credentials = (login_data.email, login_data.password)
    return await authenticate_user(db, *credentials)
async def authenticate_user(db: AsyncSession, email: str, password: str) -> dict:
    """Verify an email/password pair and issue a bearer access token.

    Looks up the ``User`` with a matching ``email``, checks ``password``
    against its stored ``hashed_password`` and, on success, returns an
    OAuth2-style token payload.

    Args:
        db: Async database session used for the lookup.
        email: Email address identifying the user.
        password: Plain-text password supplied by the client.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            email is unknown or the password is wrong. Both cases share one
            message so the response does not reveal which emails exist.
    """
    # Bearer-token 401 responses should advertise the auth scheme (RFC 6750);
    # FastAPI's OAuth2 tutorial sets this header for the same reason.
    invalid_credentials = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect email or password",
        headers={"WWW-Authenticate": "Bearer"},
    )

    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()

    if user is None:
        # Spend comparable hashing work on the "no such user" path so the
        # response time does not disclose whether the email is registered.
        # NOTE(review): assumes get_password_hash and verify_password use the
        # same algorithm/cost settings — confirm in core.security.
        get_password_hash(password)
        raise invalid_credentials

    if not verify_password(password, user.hashed_password):
        raise invalid_credentials

    # NOTE(review): `is_active` exists on User (set in `register`) but is not
    # checked here — confirm inactive accounts are rejected elsewhere.
    access_token = create_access_token(user.id)
    return {"access_token": access_token, "token_type": "bearer"}
async def register(
    user_data: UserCreate,
    db: AsyncSession = Depends(get_db)
) -> Any:
    """Create a new user account and return it with its roles loaded.

    Args:
        user_data: Registration payload. When ``username`` is missing or empty,
            the text before the first ``@`` in ``email`` is used instead.
        db: Async database session provided by ``get_db``.

    Returns:
        The newly created ``User`` with the ``roles`` relationship eagerly
        loaded.

    Raises:
        HTTPException: 400 when the email is already registered, or when the
            INSERT violates a database integrity constraint (for example a
            concurrent registration for the same email).
    """
    # Fast path: reject a known duplicate email with a clear message before
    # paying for password hashing.
    stmt = select(User).where(User.email == user_data.email)
    result = await db.execute(stmt)
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    # Derive a username from the email's local part if none was provided.
    username = user_data.username or user_data.email.split('@')[0]

    # SECURITY(review): is_superuser and is_active are copied straight from
    # the request body. If this endpoint is reachable without authentication,
    # any caller can create a superuser — confirm UserCreate or a route
    # guard prevents this.
    user = User(
        email=user_data.email,
        username=username,
        full_name=user_data.full_name,
        hashed_password=get_password_hash(user_data.password),
        is_active=user_data.is_active,
        is_superuser=user_data.is_superuser,
        branch_id=user_data.branch_id,
    )
    db.add(user)
    try:
        # Flush so the INSERT runs and the primary key is assigned while the
        # instance is still loaded. Reading `user.id` *after* commit would
        # need an implicit refresh, which AsyncSession cannot do when
        # expire_on_commit is enabled.
        await db.flush()
        user_id = user.id
        await db.commit()
    except IntegrityError as exc:
        # The pre-check above is not atomic. A concurrent registration, or
        # another constraint violation such as an invalid foreign key, lands
        # here. Report it as a client error instead of an unhandled 500.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User could not be created (duplicate or invalid data)",
        ) from exc

    # Re-read the row with `roles` eagerly loaded so callers can access the
    # relationship without a lazy load.
    stmt = select(User).options(selectinload(User.roles)).where(User.id == user_id)
    result = await db.execute(stmt)
    return result.scalar_one()