Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException, status | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
| from sqlalchemy.orm import Session | |
| from jose import JWTError, jwt | |
| # Import local modules | |
| from . import models, schemas, database | |
| from .utils import verify_password, create_access_token, get_password_hash, SECRET_KEY, ALGORITHM | |
# Create any tables declared on models.Base that do not already exist.
# NOTE(review): this executes as a side effect of importing this router
# module; consider moving schema creation to application startup/migrations.
models.Base.metadata.create_all(bind=database.engine)

# Every endpoint on this router is served under the "/auth" path prefix and
# grouped under the "Authentication" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/auth", tags=["Authentication"])

# Extracts the bearer token from the Authorization header; tokenUrl tells the
# OpenAPI docs where clients obtain a token (the router's "/auth" prefix + "/token").
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
# ==========================================
# Authentication dependency (restored): resolves a bearer token to a user
# ==========================================
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(database.get_db)):
    """Return the ``models.User`` identified by the request's bearer token.

    The token is decoded and verified with ``SECRET_KEY`` / ``ALGORITHM``.
    Its ``sub`` claim is treated as the username and looked up in the database.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.
        db: SQLAlchemy session supplied by ``database.get_db``.

    Returns:
        The matching ``models.User`` row.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token cannot be decoded/verified, has no ``sub`` claim, or names a
            user that does not exist.
    """
    # NOTE(review): this is ``async def`` but runs a blocking, synchronous
    # SQLAlchemy query on the event loop. FastAPI runs plain ``def``
    # dependencies in a threadpool; consider switching if all callers use
    # this only through ``Depends``.
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        # Any decode/verification failure reported by jose maps to a 401.
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    account = (
        db.query(models.User)
        .filter(models.User.username == subject)
        .first()
    )
    if account is None:
        raise unauthorized
    return account
# ==========================================
# Endpoints
# ==========================================
def register_user(user: schemas.UserCreate, db: Session = Depends(database.get_db)):
    """Create a new user account from a ``schemas.UserCreate`` payload.

    The plaintext password is hashed with ``get_password_hash``; only the hash
    is stored, in ``hashed_password``.

    Args:
        user: Registration payload; ``username``, ``email`` and ``password``
            are read from it.
        db: SQLAlchemy session supplied by ``database.get_db``.

    Returns:
        The newly created, refreshed ``models.User`` instance.

    Raises:
        HTTPException: 400 if the username is already registered, or if the
            database rejects the insert with an integrity error (for example a
            concurrent request registering the same username first).
    """
    # Function-local import: sqlalchemy is already a dependency of this module.
    from sqlalchemy.exc import IntegrityError

    # NOTE(review): no route decorator (e.g. ``@router.post(...)``) is visible
    # on this function, so as shown it is not registered on ``router``. When
    # adding one, set a ``response_model`` that omits ``hashed_password``;
    # the ORM object returned below carries the password hash.

    # Fast path: reject an already-taken username with a clear message.
    existing = db.query(models.User).filter(models.User.username == user.username).first()
    if existing:
        raise HTTPException(status_code=400, detail="Username already registered")

    new_user = models.User(
        username=user.username,
        email=user.email,
        hashed_password=get_password_hash(user.password),
    )
    db.add(new_user)
    try:
        db.commit()
    except IntegrityError as exc:
        # The check above is check-then-insert, so two concurrent requests can
        # both pass it; the database constraint (presumably a unique index on
        # username and/or email -- confirm in models.User) is the real guard.
        # Roll back so the session stays usable, and answer 400 instead of an
        # unhandled 500.
        db.rollback()
        raise HTTPException(
            status_code=400, detail="Username or email already registered"
        ) from exc
    db.refresh(new_user)
    return new_user
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(database.get_db)
):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: OAuth2 password-flow form; ``username`` and ``password``
            are read from it.
        db: SQLAlchemy session supplied by ``database.get_db``.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``, where the token
        comes from ``create_access_token`` with ``sub`` set to the username.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when no
            such user exists or the password does not verify.
    """
    # NOTE(review): no route decorator is visible on this function. Given
    # ``tokenUrl="auth/token"`` and the router prefix "/auth", it presumably
    # should be exposed as ``@router.post("/token")`` -- confirm.
    # NOTE(review): for unknown usernames ``verify_password`` is skipped, so
    # response timing may reveal which usernames exist; consider verifying
    # against a dummy hash in that case.
    account = (
        db.query(models.User)
        .filter(models.User.username == form_data.username)
        .first()
    )

    if account and verify_password(form_data.password, account.hashed_password):
        token = create_access_token(data={"sub": account.username})
        return {"access_token": token, "token_type": "bearer"}

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )