Spaces:
Paused
Paused
Update app/api/auth.py
Browse files- app/api/auth.py +86 -82
app/api/auth.py
CHANGED
|
@@ -1,82 +1,86 @@
|
|
| 1 |
-
from fastapi import APIRouter, Depends, HTTPException, status, Form, Body
|
| 2 |
-
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
| 3 |
-
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
-
from sqlalchemy import select
|
| 5 |
-
from sqlalchemy.orm import selectinload
|
| 6 |
-
from ..core.security import create_access_token, verify_password, get_password_hash
|
| 7 |
-
from ..db.database import get_db
|
| 8 |
-
from ..db.models import User
|
| 9 |
-
from ..db.schemas import UserCreate, UserInDB, LoginData
|
| 10 |
-
from datetime import timedelta
|
| 11 |
-
from typing import Any
|
| 12 |
-
|
| 13 |
-
router = APIRouter()
|
| 14 |
-
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
| 15 |
-
|
| 16 |
-
@router.post("/login/form")
|
| 17 |
-
async def login_form(
|
| 18 |
-
form_data: OAuth2PasswordRequestForm = Depends(),
|
| 19 |
-
db: AsyncSession = Depends(get_db)
|
| 20 |
-
) -> Any:
|
| 21 |
-
return await authenticate_user(db, form_data.username, form_data.password)
|
| 22 |
-
|
| 23 |
-
@router.post("/login")
|
| 24 |
-
async def login_json(
|
| 25 |
-
login_data: LoginData,
|
| 26 |
-
db: AsyncSession = Depends(get_db)
|
| 27 |
-
) -> Any:
|
| 28 |
-
return await authenticate_user(db, login_data.email, login_data.password)
|
| 29 |
-
|
| 30 |
-
async def authenticate_user(db: AsyncSession, email: str, password: str) -> dict:
|
| 31 |
-
stmt = select(User).where(User.email == email)
|
| 32 |
-
result = await db.execute(stmt)
|
| 33 |
-
user = result.scalar_one_or_none()
|
| 34 |
-
|
| 35 |
-
if not user:
|
| 36 |
-
raise HTTPException(
|
| 37 |
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 38 |
-
detail="Incorrect email or password",
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
if not verify_password(password, user.hashed_password):
|
| 42 |
-
raise HTTPException(
|
| 43 |
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 44 |
-
detail="Incorrect email or password",
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
access_token = create_access_token(user.id)
|
| 48 |
-
return {"access_token": access_token, "token_type": "bearer"}
|
| 49 |
-
|
| 50 |
-
@router.post("/register", response_model=UserInDB)
|
| 51 |
-
async def register(
|
| 52 |
-
user_data: UserCreate,
|
| 53 |
-
db: AsyncSession = Depends(get_db)
|
| 54 |
-
) -> Any:
|
| 55 |
-
# Check if user exists by email
|
| 56 |
-
stmt = select(User).where(User.email == user_data.email)
|
| 57 |
-
result = await db.execute(stmt)
|
| 58 |
-
if result.scalar_one_or_none():
|
| 59 |
-
raise HTTPException(
|
| 60 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
| 61 |
-
detail="Email already registered",
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
#
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Form, Body
|
| 2 |
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
| 3 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
+
from sqlalchemy import select
|
| 5 |
+
from sqlalchemy.orm import selectinload
|
| 6 |
+
from ..core.security import create_access_token, verify_password, get_password_hash
|
| 7 |
+
from ..db.database import get_db
|
| 8 |
+
from ..db.models import User
|
| 9 |
+
from ..db.schemas import UserCreate, UserInDB, LoginData
|
| 10 |
+
from datetime import timedelta
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
router = APIRouter()
|
| 14 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
| 15 |
+
|
| 16 |
+
@router.post("/login/form")
|
| 17 |
+
async def login_form(
|
| 18 |
+
form_data: OAuth2PasswordRequestForm = Depends(),
|
| 19 |
+
db: AsyncSession = Depends(get_db)
|
| 20 |
+
) -> Any:
|
| 21 |
+
return await authenticate_user(db, form_data.username, form_data.password)
|
| 22 |
+
|
| 23 |
+
@router.post("/login")
|
| 24 |
+
async def login_json(
|
| 25 |
+
login_data: LoginData,
|
| 26 |
+
db: AsyncSession = Depends(get_db)
|
| 27 |
+
) -> Any:
|
| 28 |
+
return await authenticate_user(db, login_data.email, login_data.password)
|
| 29 |
+
|
| 30 |
+
async def authenticate_user(db: AsyncSession, email: str, password: str) -> dict:
|
| 31 |
+
stmt = select(User).where(User.email == email)
|
| 32 |
+
result = await db.execute(stmt)
|
| 33 |
+
user = result.scalar_one_or_none()
|
| 34 |
+
|
| 35 |
+
if not user:
|
| 36 |
+
raise HTTPException(
|
| 37 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 38 |
+
detail="Incorrect email or password",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
if not verify_password(password, user.hashed_password):
|
| 42 |
+
raise HTTPException(
|
| 43 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 44 |
+
detail="Incorrect email or password",
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
access_token = create_access_token(user.id)
|
| 48 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
| 49 |
+
|
| 50 |
+
@router.post("/register", response_model=UserInDB)
|
| 51 |
+
async def register(
|
| 52 |
+
user_data: UserCreate,
|
| 53 |
+
db: AsyncSession = Depends(get_db)
|
| 54 |
+
) -> Any:
|
| 55 |
+
# Check if user exists by email
|
| 56 |
+
stmt = select(User).where(User.email == user_data.email)
|
| 57 |
+
result = await db.execute(stmt)
|
| 58 |
+
if result.scalar_one_or_none():
|
| 59 |
+
raise HTTPException(
|
| 60 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 61 |
+
detail="Email already registered",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Extract username from email if not provided
|
| 65 |
+
username = user_data.username or user_data.email.split('@')[0]
|
| 66 |
+
|
| 67 |
+
# Create new user
|
| 68 |
+
user = User(
|
| 69 |
+
email=user_data.email,
|
| 70 |
+
username=username,
|
| 71 |
+
full_name=user_data.full_name,
|
| 72 |
+
hashed_password=get_password_hash(user_data.password),
|
| 73 |
+
is_active=user_data.is_active,
|
| 74 |
+
is_superuser=user_data.is_superuser,
|
| 75 |
+
branch_id=user_data.branch_id
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
db.add(user)
|
| 79 |
+
await db.commit()
|
| 80 |
+
|
| 81 |
+
# Refresh user with roles relationship loaded
|
| 82 |
+
stmt = select(User).options(selectinload(User.roles)).where(User.id == user.id)
|
| 83 |
+
result = await db.execute(stmt)
|
| 84 |
+
user = result.scalar_one()
|
| 85 |
+
|
| 86 |
+
return user
|