# sceneweaver / auth.py
# mung-bean's picture
# pls work
# 2b93829
import os
from sqlalchemy.orm import Session
from models import User
from schemas import UserCreate
from passlib.context import CryptContext
from jose import jwt, JWTError
from datetime import datetime, timedelta, timezone
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from typing import Union
from dotenv import load_dotenv
# Load variables from a .env file (if one is found) into the process
# environment before the settings below are read with os.getenv.
load_dotenv()
def get_env_var(name):
    """Return environment variable *name*, or None when it is unset or empty."""
    # `or None` collapses both a missing variable and an empty string to None.
    return os.getenv(name) or None
# Key used to sign and verify JWTs; None when SECRET_KEY is unset or empty.
SECRET_KEY = get_env_var("SECRET_KEY")
# Algorithm the token verifiers pass to jwt.decode. Falls back to HS256 -- the
# algorithm create_access_token uses when none is configured -- so decoding
# does not break just because the variable is missing.
ALGORITHM = get_env_var("ALGORITHM") or "HS256"
# Raw string from the environment (or None); it is not converted to a number here.
ACCESS_TOKEN_EXPIRE_MINUTES = get_env_var("ACCESS_TOKEN_EXPIRE_MINUTES")
# FastAPI dependency that extracts the bearer token from the request;
# "/api/token" is the token URL advertised to clients.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/token")
# passlib context used to hash and verify passwords with bcrypt.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password):
    """Hash *password* with the module-level passlib context and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Return the result of checking *plain_password* against *hashed_password*."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_user_by_username(db: Session, username: str):
    """Return the first User row whose username equals *username*."""
    by_username = db.query(User).filter(User.username == username)
    return by_username.first()
def get_user_by_email(db: Session, email: str):
    """Return the first User row whose email equals *email*."""
    by_email = db.query(User).filter(User.email == email)
    return by_email.first()
def create_user(db: Session, user: UserCreate):
    """Persist a new User built from *user* and return the stored row.

    Only the hash of the password is stored; the plain password from *user*
    is passed through get_password_hash first.
    """
    password_digest = get_password_hash(user.password)
    new_row = User(
        username=user.username,
        email=user.email,
        hashed_password=password_digest,
    )
    db.add(new_row)
    db.commit()
    # Reload the row so the returned object reflects what the database stored.
    db.refresh(new_row)
    return new_row
def get_current_user(db: Session, token: str = Depends(oauth2_scheme)):
    """Resolve the bearer *token* to the User it identifies.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token cannot be decoded, carries no "sub" claim, or names a user
            that is not found.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    account = get_user_by_username(db, subject)
    if account is None:
        raise unauthorized
    return account
def authenticate_user(db: Session, username: str, password: str):
    """Return the user matching *username* and *password*, otherwise False."""
    candidate = get_user_by_username(db, username)
    # Short-circuit: the hash check only runs when a user was actually found.
    if candidate and pwd_context.verify(password, candidate.hashed_password):
        return candidate
    return False
def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Token lifetime. A falsy value (None or a zero
            timedelta) selects the 15-minute default.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    to_encode = data.copy()
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    # Sign with the same algorithm the verifiers decode with; a hard-coded
    # value here would make every token fail verification whenever ALGORITHM
    # is configured to something else. HS256 remains the fallback when unset.
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM or "HS256")
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode the bearer *token* and return its claims.

    Returns:
        The decoded payload returned by ``jwt.decode``.

    Raises:
        HTTPException: 401 when the token cannot be decoded or has no "sub"
            claim.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        # Same challenge header get_current_user attaches to its 401 responses.
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Keep the try body to the one call that can raise JWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    if payload.get("sub") is None:
        raise credentials_exception
    return payload
def verify_token_string(token: str = Depends(oauth2_scheme)):
    """Decode the bearer *token* and return the value of its "sub" claim.

    Returns:
        The "sub" claim from the decoded payload.

    Raises:
        HTTPException: 401 when the token cannot be decoded or has no "sub"
            claim.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        # Same challenge header get_current_user attaches to its 401 responses.
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Keep the try body to the one call that can raise JWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    username = payload.get("sub")
    if username is None:
        raise credentials_exception
    return username