# Scraped page header (not part of the module): Spaces: Sleeping | File size: 3,083 Bytes | 438ec1c
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from jose import JWTError, jwt
# Import local modules
from . import models, schemas, database
from .utils import verify_password, create_access_token, get_password_hash, SECRET_KEY, ALGORITHM
# Create the tables registered on the ORM base as soon as this module is imported.
models.Base.metadata.create_all(bind=database.engine)

# Router that groups every route in this module under the /auth prefix.
router = APIRouter(prefix="/auth", tags=["Authentication"])

# OAuth2 bearer-token scheme; tokenUrl names the endpoint that issues tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
# ==========================================
# Dependency: current authenticated user
# ==========================================
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(database.get_db),
):
    """Resolve a bearer token to the matching ``models.User`` row.

    The token is decoded with ``SECRET_KEY`` / ``ALGORITHM`` and its ``sub``
    claim is used as the username for the database lookup.

    Args:
        token: Raw token string supplied by ``oauth2_scheme``.
        db: Database session supplied by ``database.get_db``.

    Returns:
        The first ``models.User`` whose ``username`` equals the ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            decoding fails, the ``sub`` claim is missing, or no user has
            that username.
    """
    # One shared error for every failure mode below.
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode call can raise JWTError, so keep the try body minimal.
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    # The token may name a user that is not in the database.
    matching_users = db.query(models.User).filter(models.User.username == subject)
    account = matching_users.first()
    if account is None:
        raise unauthorized
    return account
# ==========================================
# Endpoints
# ==========================================
@router.post("/register", response_model=schemas.UserOut)
def register_user(user: schemas.UserCreate, db: Session = Depends(database.get_db)):
    """Create a new user account and return the stored row.

    Args:
        user: Registration payload; its ``username``, ``email`` and
            ``password`` attributes are read.
        db: Database session supplied by ``database.get_db``.

    Returns:
        The newly persisted ``models.User``, refreshed from the database
        (the route declares ``schemas.UserOut`` as its response model).

    Raises:
        HTTPException: 400 ``"Username already registered"`` when the
            username is taken, whether detected by the up-front check or
            by a failed commit.
        sqlalchemy.exc.IntegrityError: Re-raised unchanged when the commit
            fails for a reason other than a duplicate username.
    """
    # Function-scope import keeps this block self-contained; sqlalchemy is
    # already a dependency of this module.
    from sqlalchemy.exc import IntegrityError

    def username_taken() -> bool:
        """Report whether a row with the requested username exists."""
        existing = (
            db.query(models.User)
            .filter(models.User.username == user.username)
            .first()
        )
        return existing is not None

    # 1. Reject usernames that are already present.
    if username_taken():
        raise HTTPException(status_code=400, detail="Username already registered")

    # 2. Store only the hash, never the plain-text password.
    new_user = models.User(
        username=user.username,
        email=user.email,
        hashed_password=get_password_hash(user.password),
    )

    # 3. Persist. The check in step 1 is not atomic with this insert, so
    #    another request can claim the same username in between. Presumably
    #    the users table enforces uniqueness (not visible here - confirm);
    #    if so, that collision surfaces as IntegrityError on commit.
    db.add(new_user)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()  # leave the session usable after the failed flush
        if username_taken():
            raise HTTPException(
                status_code=400, detail="Username already registered"
            ) from None
        raise  # a different constraint failed; do not mask it
    db.refresh(new_user)
    return new_user
@router.post("/token")
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(database.get_db),
):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: Submitted login form; ``username`` and ``password`` are read.
        db: Database session supplied by ``database.get_db``.

    Returns:
        A dict with ``access_token`` (built by ``create_access_token`` with
        the username as the ``sub`` claim) and ``token_type`` set to
        ``"bearer"``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            no user matches the username or ``verify_password`` rejects
            the password.
    """
    # NOTE(review): this coroutine calls the session synchronously; if `db`
    # is a blocking Session, a plain `def` handler may be preferable - confirm.
    account = (
        db.query(models.User)
        .filter(models.User.username == form_data.username)
        .first()
    )

    # Unknown user and wrong password produce the same response; the
    # password is only checked when a user row was found.
    if not (account and verify_password(form_data.password, account.hashed_password)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(data={"sub": account.username})
    return {"access_token": token, "token_type": "bearer"}