File size: 7,506 Bytes
38c555b
 
 
e631be8
 
e3c16ea
724ff38
da8dc09
 
 
 
38c555b
 
 
 
 
 
 
 
 
 
e631be8
 
38c555b
 
da8dc09
 
 
 
9b92ec5
00e8345
38c555b
 
da8dc09
 
 
 
9b92ec5
00e8345
38c555b
00e8345
 
da8dc09
e3c16ea
da8dc09
9b92ec5
00e8345
38c555b
00e8345
e3c16ea
00e8345
 
 
 
 
38c555b
 
00e8345
 
 
 
 
 
 
 
38c555b
e631be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8dc09
38c555b
 
 
 
 
 
 
00e8345
 
38c555b
00e8345
da8dc09
9b92ec5
38c555b
da8dc09
9b92ec5
00e8345
 
38c555b
 
 
 
e631be8
da8dc09
 
 
 
 
 
9b92ec5
00e8345
38c555b
00e8345
 
da8dc09
 
e3c16ea
da8dc09
 
 
 
 
9b92ec5
00e8345
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
from datetime import datetime, timedelta, timezone

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy import func
from sqlalchemy.orm import Session,Mapped

from db.database import get_db
from db.models import User
from interfaces.authModels import UserResponse,TokenData
from logger_manager import log_info, log_error

# JWT signing configuration.
# The signing key is read from the SECRET_KEY environment variable. The literal
# fallback exists only so local development works out of the box; it is public
# (committed to source control) and must not be relied on in a deployment.
# Generate a proper key (64 hex chars) with:  openssl rand -hex 32
SECRET_KEY = os.environ.get("SECRET_KEY", "09d8f7a6b5c4e3d2f1a0b9c8d7e6f5a4")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# bcrypt password hashing; deprecated="auto" lets passlib flag hashes produced
# by schemes other than the default so they can be re-hashed.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Standard bearer-token extractor; rejects requests without a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Same extractor, but returns None instead of raising when no bearer token is
# present, so callers can fall back to other token sources.
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)

def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash using ``pwd_context``.

    Returns the result of ``pwd_context.verify`` (truthy when they match).
    Any exception raised during verification is logged and re-raised as an
    HTTP 500 whose detail is the exception message.
    """
    log_info("Verifying password")
    try:
        matches = pwd_context.verify(plain_password, hashed_password)
    except Exception as exc:
        message = str(exc)
        log_error(f"Error verifying password: {message}", exc)
        raise HTTPException(status_code=500, detail=message)
    return matches

def get_password_hash(password):
    """Hash a plaintext password with ``pwd_context`` and return the hash.

    Any exception raised while hashing is logged and re-raised as an HTTP 500
    whose detail is the exception message.
    """
    log_info("Hashing password")
    try:
        hashed = pwd_context.hash(password)
    except Exception as exc:
        message = str(exc)
        log_error(f"Error hashing password: {message}", exc)
        raise HTTPException(status_code=500, detail=message)
    return hashed

def get_user(db, email: str):
    """Fetch the first ``User`` whose e-mail matches ``email``, ignoring case.

    Returns the matching ORM object, or ``None`` when there is no match.
    Any exception raised while querying is logged and re-raised as an HTTP 500
    whose detail is the exception message.
    """
    log_info(f"Getting user: {email}")
    try:
        query = db.query(User)
        same_email = func.lower(User.email) == email.lower()
        return query.filter(same_email).first()
    except Exception as exc:
        message = str(exc)
        log_error(f"Error getting user: {message}", exc)
        raise HTTPException(status_code=500, detail=message)

def authenticate_user(db: Session, username: str, password: str):
    """Return the user identified by ``username``/``password``, or ``None``.

    ``username`` is matched case-insensitively against ``User.email`` through
    ``get_user``, so the lookup logic lives in one place.

    Args:
        db: Active SQLAlchemy session.
        username: The e-mail address supplied at login.
        password: The plaintext password supplied at login.

    Returns:
        The ``User`` object when the credentials are valid, otherwise ``None``.
    """
    user = get_user(db, username)
    if user is None:
        # Spend comparable hashing time on unknown accounts so response timing
        # does not reveal which e-mail addresses are registered.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Encode ``data`` as a signed JWT carrying an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed, e.g. ``{"sub": email}``. The dict is copied,
            not mutated.
        expires_delta: Token lifetime. Defaults to 15 minutes when omitted.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Explicit None check so a zero-length timedelta is honoured instead of
    # being silently replaced by the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated since 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

# Flexible token extractor supporting several token sources.
async def get_token_from_request(request: Request | None = None, oauth_token: str | None = None):
    """Return the auth token for a request, or ``None`` if none is found.

    Sources are tried in this order:
      1. ``oauth_token`` (already extracted by the OAuth2 bearer dependency);
      2. an ``Authorization: Bearer <token>`` header;
      3. the ``x-ip-token`` header (custom header used on Hugging Face Spaces);
      4. a ``token`` query parameter.

    NOTE(review): tokens passed in the query string commonly end up in proxy
    and access logs; confirm this fallback is still required.
    """
    if oauth_token:
        return oauth_token

    if request is None:
        return None

    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, sep, credentials = auth_header.partition(" ")
        # Strip only the leading scheme (str.replace would also remove any later
        # "Bearer " inside the value); auth schemes are case-insensitive (RFC 7235).
        if sep and scheme.lower() == "bearer":
            return credentials

    hf_token = request.headers.get("x-ip-token")
    if hf_token:
        log_info("Using token from Hugging Face x-ip-token header")
        return hf_token

    token_param = request.query_params.get("token")
    if token_param:
        log_info("Using token from query parameter")
        return token_param

    return None

# Dependency resolving the authenticated user from any supported token source.
async def get_current_user(
    request: Request,
    db: Session = Depends(get_db),
    oauth_token: str = Depends(oauth2_scheme_optional)
):
    """Resolve the authenticated ``User`` for the current request.

    The token is located with ``get_token_from_request`` and decoded as a JWT
    signed with ``SECRET_KEY``/``ALGORITHM``. Its ``sub`` claim is taken as the
    user's e-mail address.

    Returns:
        The ``User`` ORM object for the token's subject.

    Raises:
        HTTPException(401): no token, JWT verification failure, missing
            ``sub`` claim, or no user with that e-mail.
        HTTPException(500): any other error while processing the token.
    """
    log_info("Getting current user with flexible auth")
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    token = await get_token_from_request(request, oauth_token)
    if not token:
        log_error("No authentication token found")
        raise credentials_exception

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str | None = payload.get("sub")
        if email is None:
            log_error("Token missing 'sub' claim")
            raise credentials_exception
        token_data = TokenData(email=email)
    except HTTPException:
        # Let our own 401 propagate unchanged; without this clause the generic
        # handler below would convert it into a 500.
        raise
    except JWTError as e:
        log_error(f"JWT verification failed: {str(e)}", e)
        raise credentials_exception from e
    except Exception as e:
        log_error(f"Token processing error: {str(e)}", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    user = get_user(db, email=token_data.email)
    if user is None:
        log_error(f"User not found: {token_data.email}")
        raise credentials_exception

    return user

# Dependency for routes that require an active user (flexible auth).
async def get_current_active_user(
    request: Request,
    db: Session = Depends(get_db),
    oauth_token: str = Depends(oauth2_scheme_optional)
):
    """Return the authenticated user as a ``UserResponse``, rejecting inactive accounts.

    Authentication is delegated to ``get_current_user``; its exceptions
    propagate unchanged.

    Raises:
        HTTPException(400): the authenticated user is not active.
    """
    user = await get_current_user(request, db, oauth_token)
    if user.is_active:
        return UserResponse.from_orm(user)
    raise HTTPException(status_code=400, detail="Inactive user")

async def get_current_user_old(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
    """Legacy dependency: resolve the user from a standard OAuth2 bearer token only.

    Returns:
        The ``User`` ORM object whose e-mail matches the token's ``sub`` claim.

    Raises:
        HTTPException(401): JWT verification failure, missing ``sub`` claim,
            or no user with that e-mail.
        HTTPException(500): any other error while processing the token.
    """
    log_info("Getting current user")
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str | None = payload.get("sub")
        if email is None:
            raise credentials_exception
        token_data = TokenData(email=email)
    except HTTPException:
        # Keep the 401 above from being turned into a 500 by the handler below.
        raise
    except JWTError as e:
        log_error(f"JWT error: {str(e)}",e)
        raise credentials_exception from e
    except Exception as e:
        log_error(f"Error decoding token: {str(e)}",e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    user = get_user(db, email=token_data.email)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user_old(current_user: User = Depends(get_current_user_old)):
    """Legacy dependency: return the authenticated user as a ``UserResponse`` if active.

    Raises:
        HTTPException(400): the user is not active.
        HTTPException(500): the user could not be converted to ``UserResponse``.
    """
    log_info("Getting current active user")
    # Checked outside the try: inside it, this 400 was caught by the generic
    # handler below and re-raised as a 500.
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    try:
        return UserResponse.from_orm(current_user)
    except Exception as e:
        log_error(f"Error getting current active user: {str(e)}",e)
        raise HTTPException(status_code=500, detail=str(e)) from e

def create_user(db: Session, name: str, email: str, password: str):
    """Create, persist and return a new ``User``.

    The e-mail address is stored lower-cased and the password is stored only
    as a hash produced by ``get_password_hash``.

    Args:
        db: Active SQLAlchemy session; the new row is committed on it.
        name: Display name.
        email: E-mail address (lower-cased before storing).
        password: Plaintext password to hash.

    Returns:
        The committed and refreshed ``User`` object.

    Raises:
        HTTPException(500): hashing failed, or the insert/commit failed (the
            session is rolled back first so it remains usable).
    """
    log_info(f"Creating user: {name}")
    # get_password_hash already logs and raises its own HTTPException(500);
    # calling it outside the try avoids double-wrapping that error.
    hashed_password = get_password_hash(password)
    try:
        db_user = User(name=name, email=email.lower(), hashed_password=hashed_password)
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
    except Exception as e:
        # A failed flush/commit leaves the session in a failed state; roll back
        # so later work on this session does not error out.
        db.rollback()
        log_error(f"Error creating user: {str(e)}",e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return db_user