Spaces:
Running
Running
File size: 7,506 Bytes
38c555b e631be8 e3c16ea 724ff38 da8dc09 38c555b e631be8 38c555b da8dc09 9b92ec5 00e8345 38c555b da8dc09 9b92ec5 00e8345 38c555b 00e8345 da8dc09 e3c16ea da8dc09 9b92ec5 00e8345 38c555b 00e8345 e3c16ea 00e8345 38c555b 00e8345 38c555b e631be8 da8dc09 38c555b 00e8345 38c555b 00e8345 da8dc09 9b92ec5 38c555b da8dc09 9b92ec5 00e8345 38c555b e631be8 da8dc09 9b92ec5 00e8345 38c555b 00e8345 da8dc09 e3c16ea da8dc09 9b92ec5 00e8345 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import os
from datetime import datetime, timedelta, timezone

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy import func
from sqlalchemy.orm import Session,Mapped

from db.database import get_db
from db.models import User
from interfaces.authModels import UserResponse,TokenData
from logger_manager import log_info, log_error
# Secret used to sign and verify JWTs. Generate a strong value with:
#   openssl rand -hex 32
# and supply it through the SECRET_KEY environment variable.
# NOTE(review): the fallback literal below ships with the source, so anyone who
# can read it can mint valid tokens. It is kept only so existing deployments
# (and tokens already issued) keep working when the variable is unset or empty.
SECRET_KEY = os.environ.get("SECRET_KEY") or "09d8f7a6b5c4e3d2f1a0b9c8d7e6f5a4"
# JWT signing algorithm; HS256 is symmetric, keyed by SECRET_KEY.
ALGORITHM = "HS256"
# Token lifetime in minutes. Not referenced by create_access_token in this file
# (its own default is 15 minutes); presumably the login route passes
# timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) — TODO confirm.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# passlib context used by verify_password / get_password_hash (bcrypt only).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-token extractor that rejects requests without a token (auto_error
# defaults to True).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Same extractor with auto_error=False: yields None instead of raising, so
# callers can fall back to other token sources (see get_token_from_request).
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: the candidate password supplied by the client.
        hashed_password: the stored hash to compare against.

    Returns:
        The boolean result of ``pwd_context.verify``.

    Raises:
        HTTPException: status 500 if ``pwd_context.verify`` raises anything;
            the error is logged first.
    """
    log_info("Verifying password")
    try:
        is_match = pwd_context.verify(plain_password, hashed_password)
    except Exception as exc:
        # NOTE(review): detail=str(exc) echoes internal error text to the client.
        log_error(f"Error verifying password: {str(exc)}", exc)
        raise HTTPException(status_code=500, detail=str(exc))
    return is_match
def get_password_hash(password):
    """Hash a plaintext password with the module's passlib context.

    Args:
        password: the plaintext password to hash.

    Returns:
        The hash string produced by ``pwd_context.hash``.

    Raises:
        HTTPException: status 500 if hashing raises; the error is logged first.
    """
    log_info("Hashing password")
    try:
        digest = pwd_context.hash(password)
    except Exception as exc:
        log_error(f"Error hashing password: {str(exc)}", exc)
        raise HTTPException(status_code=500, detail=str(exc))
    return digest
def get_user(db, email: str):
    """Fetch the user whose e-mail matches ``email``, ignoring case.

    Args:
        db: SQLAlchemy session used to run the query.
        email: address to look up; both sides are lower-cased before comparing.

    Returns:
        The first matching ``User`` row, or ``None`` when there is no match.

    Raises:
        HTTPException: status 500 if the query fails; the error is logged first.
    """
    # NOTE(review): this logs the raw e-mail address at info level.
    log_info(f"Getting user: {email}")
    try:
        matches = db.query(User).filter(func.lower(User.email) == email.lower())
        return matches.first()
    except Exception as exc:
        log_error(f"Error getting user: {str(exc)}", exc)
        raise HTTPException(status_code=500, detail=str(exc))
def authenticate_user(db: Session, username: str, password: str):
    """Return the user identified by ``username``/``password``, or ``None``.

    Args:
        db: SQLAlchemy session.
        username: the login name, matched case-insensitively against
            ``User.email`` via ``get_user``.
        password: the plaintext password to verify against the stored hash.

    Returns:
        The ``User`` on success; ``None`` if no user matches or the password
        does not verify.

    Raises:
        HTTPException: status 500 if the lookup or hash verification itself
            fails (raised by ``get_user`` / ``verify_password``).
    """
    user = get_user(db, email=username)
    if not user:
        # Still spend roughly one hash verification so an unknown e-mail and a
        # wrong password take comparable time. This makes it harder to
        # enumerate registered accounts by timing the login endpoint.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Encode ``data`` as a signed JWT carrying an ``exp`` claim.

    Args:
        data: claims to embed. The dict is copied, so the caller's dict is not
            modified.
        expires_delta: token lifetime measured from now. Only ``None`` selects
            the 15-minute default; any explicit delta (including zero) is used
            as given.

    Returns:
        The value returned by ``jwt.encode``, signed with ``SECRET_KEY`` using
        ``ALGORITHM``.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware "now" instead of datetime.utcnow(), which is deprecated
    # since Python 3.12; the expiry is still an instant in UTC.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# Flexible token extractor used by get_current_user.
async def get_token_from_request(request: Request = None, oauth_token: str = None):
    """Extract an auth token from the first source that provides one.

    Lookup order:
      1. ``oauth_token`` (as already extracted by ``OAuth2PasswordBearer``).
      2. ``Authorization: Bearer <token>`` header. The scheme is matched
         case-insensitively, as ``OAuth2PasswordBearer`` does.
      3. The Hugging Face Spaces ``x-ip-token`` header.
      4. The ``token`` query parameter.

    Args:
        request: the incoming request; may be ``None``, in which case only
            ``oauth_token`` is considered.
        oauth_token: a token already extracted by the OAuth2 dependency.

    Returns:
        The token string, or ``None`` if no source supplied a non-empty one.
    """
    if oauth_token:
        return oauth_token
    if request is None:
        return None

    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, _, credentials = auth_header.partition(" ")
        # Strip only the leading scheme. The previous str.replace("Bearer ", "")
        # also removed any later occurrence inside the value, and an empty
        # "Bearer " header now falls through to the other sources.
        if scheme.lower() == "bearer" and credentials:
            return credentials

    # NOTE(review): this header is read as this app's JWT. If the platform
    # injects x-ip-token with some other value, it will shadow the query-param
    # fallback below and fail JWT decoding — confirm intended.
    hf_token = request.headers.get("x-ip-token")
    if hf_token:
        log_info("Using token from Hugging Face x-ip-token header")
        return hf_token

    # NOTE(review): tokens in URLs can leak into access logs and browser
    # history; the header forms above are preferable.
    token_param = request.query_params.get("token")
    if token_param:
        log_info("Using token from query parameter")
        return token_param
    return None
# Flexible-auth dependency: resolves the current user from any supported token source.
async def get_current_user(
    request: Request,
    db: Session = Depends(get_db),
    oauth_token: str = Depends(oauth2_scheme_optional)
):
    """Resolve the authenticated ``User`` for a request.

    The token is obtained via ``get_token_from_request`` and decoded with
    ``SECRET_KEY``/``ALGORITHM``. Its ``sub`` claim is then looked up as an
    e-mail address with ``get_user``.

    Args:
        request: the incoming request, searched for a token.
        db: SQLAlchemy session (injected).
        oauth_token: the token from the optional OAuth2 scheme, or ``None``.

    Returns:
        The matching ``User`` ORM object.

    Raises:
        HTTPException: 401 when no token is found, decoding raises
            ``JWTError``, the ``sub`` claim is missing, or no user matches.
            500 for any other error while processing the token.
    """
    log_info("Getting current user with flexible auth")
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Get token from any available source
    token = await get_token_from_request(request, oauth_token)
    if not token:
        log_error("No authentication token found")
        raise credentials_exception
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str = payload.get("sub")
        if email is None:
            log_error("Token missing 'sub' claim")
            raise credentials_exception
        token_data = TokenData(email=email)
    except JWTError as e:
        log_error(f"JWT verification failed: {str(e)}", e)
        raise credentials_exception
    except HTTPException:
        # Propagate the 401 raised above unchanged; without this clause the
        # generic handler below would turn it into a 500.
        raise
    except Exception as e:
        log_error(f"Token processing error: {str(e)}", e)
        raise HTTPException(status_code=500, detail=str(e))
    # Find the user
    user = get_user(db, email=token_data.email)
    if user is None:
        log_error(f"User not found: {token_data.email}")
        raise credentials_exception
    return user
# Flexible-auth dependency that additionally rejects inactive accounts.
async def get_current_active_user(
    request: Request,
    db: Session = Depends(get_db),
    oauth_token: str = Depends(oauth2_scheme_optional)
):
    """Return the authenticated user as a ``UserResponse`` if the account is active.

    Authentication is delegated to ``get_current_user``, so its 401/500 errors
    propagate unchanged.

    Raises:
        HTTPException: 400 "Inactive user" when ``is_active`` is falsy.
    """
    user = await get_current_user(request, db, oauth_token)
    if user.is_active:
        return UserResponse.from_orm(user)
    raise HTTPException(status_code=400, detail="Inactive user")
async def get_current_user_old(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
    """Legacy resolver that accepts only the standard OAuth2 bearer token.

    Decodes ``token`` with ``SECRET_KEY``/``ALGORITHM`` and looks up the user
    whose e-mail equals the ``sub`` claim.

    Raises:
        HTTPException: 401 when decoding raises ``JWTError``, the ``sub`` claim
            is missing, or no user matches. 500 for any other error while
            processing the token.
    """
    log_info("Getting current user")
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str = payload.get("sub")
        if email is None:
            raise credentials_exception
        token_data = TokenData(email=email)
    except JWTError as e:
        log_error(f"JWT error: {str(e)}",e)
        raise credentials_exception
    except HTTPException:
        # Propagate the 401 raised above unchanged; without this clause the
        # generic handler below would turn it into a 500.
        raise
    except Exception as e:
        log_error(f"Error decoding token: {str(e)}",e)
        raise HTTPException(status_code=500, detail=str(e))
    user = get_user(db, email=token_data.email)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user_old(current_user: User = Depends(get_current_user_old)):
    """Legacy dependency: return the current user as a ``UserResponse`` if active.

    Raises:
        HTTPException: 400 "Inactive user" when ``is_active`` is falsy; 500 if
            reading the user or building the response fails for any other
            reason.
    """
    log_info("Getting current active user")
    try:
        if not current_user.is_active:
            raise HTTPException(status_code=400, detail="Inactive user")
        return UserResponse.from_orm(current_user)
    except HTTPException:
        # Let the intentional 400 through; the generic handler below previously
        # converted it into a 500 with detail "400: Inactive user".
        raise
    except Exception as e:
        log_error(f"Error getting current active user: {str(e)}",e)
        raise HTTPException(status_code=500, detail=str(e))
def create_user(db: Session, name: str, email: str, password: str):
    """Create, persist, and return a new ``User``.

    The e-mail is stored lower-cased, matching the ``func.lower(User.email)``
    comparisons used for lookups in this module. The password is stored only
    as a hash from ``get_password_hash``.

    Args:
        db: SQLAlchemy session; committed on success, rolled back on failure.
        name: display name for the new user.
        email: e-mail address (lower-cased before storing).
        password: plaintext password to hash.

    Returns:
        The refreshed ``User`` instance.

    Raises:
        HTTPException: 500 if hashing fails (raised by ``get_password_hash``)
            or if building/committing the row fails.
    """
    log_info(f"Creating user: {name}")
    # get_password_hash already logs and raises HTTPException(500). Calling it
    # outside the try avoids re-wrapping that error as detail "500: ...".
    hashed_password = get_password_hash(password)
    try:
        db_user = User(name=name, email=email.lower(), hashed_password=hashed_password)
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
        return db_user
    except Exception as e:
        # A failed flush/commit leaves the Session needing rollback() before it
        # can be used again; don't hand a broken session back to the caller.
        db.rollback()
        log_error(f"Error creating user: {str(e)}",e)
        raise HTTPException(status_code=500, detail=str(e))