from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Body, Depends, Form, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from ..core.security import create_access_token, verify_password, get_password_hash
from ..db.database import get_db
from ..db.models import User
from ..db.schemas import UserCreate, UserInDB, LoginData

router = APIRouter()

# NOTE(review): no route in this module is served at "token"; the form-based
# login endpoint is "/login/form". Confirm tokenUrl against the prefix this
# router is mounted under, otherwise the OpenAPI "Authorize" button will post
# to a non-existent URL.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@router.post("/login/form")
async def login_form(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Log in with an OAuth2 password-flow form and return a bearer token.

    The form's ``username`` field is treated as the user's email address.
    """
    return await authenticate_user(db, form_data.username, form_data.password)


@router.post("/login")
async def login_json(
    login_data: LoginData,
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Log in with a JSON body (email + password) and return a bearer token."""
    return await authenticate_user(db, login_data.email, login_data.password)


async def authenticate_user(db: AsyncSession, email: str, password: str) -> dict:
    """Verify credentials and issue an access token.

    Args:
        db: Database session used to look up the user.
        email: Email address identifying the user.
        password: Plain-text password to check against the stored hash.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 if no user has this email or the password does not
            match. Both cases share one message so the response does not reveal
            which emails are registered.
    """
    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()

    # NOTE(review): when the email is unknown, verify_password is skipped, so
    # response timing can still differ from the wrong-password case; consider
    # verifying against a dummy hash. This also does not check user.is_active
    # -- confirm inactive accounts are rejected elsewhere.
    if user is None or not verify_password(password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            # Bearer-token 401 responses should advertise the auth scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )

    access_token = create_access_token(user.id)
    return {"access_token": access_token, "token_type": "bearer"}


@router.post("/register", response_model=UserInDB)
async def register(
    user_data: UserCreate,
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Create a user account and return it with its ``roles`` relationship loaded.

    If ``user_data.username`` is falsy, the part of the email before the first
    ``@`` is used as the username.

    Raises:
        HTTPException: 400 if the email is already registered, or if inserting
            the user violates a database integrity constraint (for example, a
            concurrent registration for the same email).
    """
    # Fast-path duplicate check that gives a clear message in the common case.
    # It cannot stop a concurrent request from inserting the same email between
    # this SELECT and our INSERT; that race is handled by the IntegrityError
    # branch below.
    stmt = select(User).where(User.email == user_data.email)
    result = await db.execute(stmt)
    if result.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    # Fall back to the email's local part when no username was supplied.
    username = user_data.username or user_data.email.split("@")[0]

    user = User(
        email=user_data.email,
        username=username,
        full_name=user_data.full_name,
        hashed_password=get_password_hash(user_data.password),
        is_active=user_data.is_active,
        is_superuser=user_data.is_superuser,
        branch_id=user_data.branch_id,
    )
    db.add(user)
    try:
        # Flush so the INSERT runs and the primary key is assigned. Read the id
        # *before* commit: with AsyncSession's default expire_on_commit=True,
        # touching user.id after commit would trigger an implicit lazy load,
        # which is not allowed in async code.
        await db.flush()
        user_id = user.id
        await db.commit()
    except IntegrityError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email or username already registered",
        ) from exc

    # Re-load the user with roles eagerly loaded, so the response model can
    # serialise the relationship without lazy IO.
    stmt = select(User).options(selectinload(User.roles)).where(User.id == user_id)
    result = await db.execute(stmt)
    return result.scalar_one()