Prathamesh Sable commited on
Commit
e631be8
·
1 Parent(s): 364330f

made to use HF auth

Browse files
Files changed (1) hide show
  1. services/auth_service.py +100 -5
services/auth_service.py CHANGED
@@ -1,8 +1,8 @@
1
  from passlib.context import CryptContext
2
  from jose import JWTError, jwt
3
  from datetime import datetime, timedelta
4
- from fastapi import Depends, HTTPException, status
5
- from fastapi.security import OAuth2PasswordBearer
6
  from sqlalchemy import func
7
  from sqlalchemy.orm import Session,Mapped
8
  from db.database import get_db
@@ -19,7 +19,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
19
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
20
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
21
 
22
-
 
23
 
24
  def verify_password(plain_password, hashed_password):
25
  log_info("Verifying password")
@@ -63,7 +64,101 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
63
  encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
64
  return encoded_jwt
65
 
66
- async def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  log_info("Getting current user")
68
  credentials_exception = HTTPException(
69
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -87,7 +182,7 @@ async def get_current_user(db: Session = Depends(get_db), token: str = Depends(o
87
  raise credentials_exception
88
  return user
89
 
90
- async def get_current_active_user(current_user: User = Depends(get_current_user)):
91
  log_info("Getting current active user")
92
  try:
93
  if not current_user.is_active:
 
1
  from passlib.context import CryptContext
2
  from jose import JWTError, jwt
3
  from datetime import datetime, timedelta
4
+ from fastapi import Depends, HTTPException, status, Request
5
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
6
  from sqlalchemy import func
7
  from sqlalchemy.orm import Session,Mapped
8
  from db.database import get_db
 
19
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
20
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
21
 
22
+ # Create an optional OAuth2 scheme that doesn't auto-error
23
+ oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
24
 
25
  def verify_password(plain_password, hashed_password):
26
  log_info("Verifying password")
 
64
  encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
65
  return encoded_jwt
66
 
67
+ # New flexible token extractor
68
+ async def get_token_from_request(request: Request = None, oauth_token: str = None):
69
+ """Extract token from various sources, prioritizing standard formats but
70
+ supporting Hugging Face Spaces custom headers"""
71
+
72
+ # First try the standard OAuth2 token if provided
73
+ if oauth_token:
74
+ return oauth_token
75
+
76
+ if request is None:
77
+ return None
78
+
79
+ # Try standard Authorization header (works in local development)
80
+ auth_header = request.headers.get("Authorization")
81
+ if auth_header and auth_header.startswith("Bearer "):
82
+ return auth_header.replace("Bearer ", "")
83
+
84
+ # Try Hugging Face's custom header
85
+ hf_token = request.headers.get("x-ip-token")
86
+ if hf_token:
87
+ log_info(f"Using token from Hugging Face x-ip-token header")
88
+ return hf_token
89
+
90
+ # Final fallback: check query parameters
91
+ token_param = request.query_params.get("token")
92
+ if token_param:
93
+ log_info(f"Using token from query parameter")
94
+ return token_param
95
+
96
+ return None
97
+
98
+ # Replace or add this function
99
+ async def get_current_user(
100
+ request: Request,
101
+ db: Session = Depends(get_db),
102
+ oauth_token: str = Depends(oauth2_scheme_optional)
103
+ ):
104
+ """Enhanced user authentication that supports both standard OAuth2
105
+ and Hugging Face Spaces deployments"""
106
+
107
+ log_info("Getting current user with flexible auth")
108
+ credentials_exception = HTTPException(
109
+ status_code=status.HTTP_401_UNAUTHORIZED,
110
+ detail="Could not validate credentials",
111
+ headers={"WWW-Authenticate": "Bearer"},
112
+ )
113
+
114
+ # Get token from any available source
115
+ token = await get_token_from_request(request, oauth_token)
116
+
117
+ if not token:
118
+ log_error("No authentication token found")
119
+ raise credentials_exception
120
+
121
+ try:
122
+ # Try to decode the token
123
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
124
+ email: str = payload.get("sub")
125
+ if email is None:
126
+ log_error("Token missing 'sub' claim")
127
+ raise credentials_exception
128
+
129
+ token_data = TokenData(email=email)
130
+
131
+ except JWTError as e:
132
+ log_error(f"JWT verification failed: {str(e)}", e)
133
+ raise credentials_exception
134
+
135
+ except Exception as e:
136
+ log_error(f"Token processing error: {str(e)}", e)
137
+ raise HTTPException(status_code=500, detail=str(e))
138
+
139
+ # Find the user
140
+ user = get_user(db, email=token_data.email)
141
+ if user is None:
142
+ log_error(f"User not found: {token_data.email}")
143
+ raise credentials_exception
144
+
145
+ return user
146
+
147
+ # Add this function for active users with flexible auth
148
+ async def get_current_active_user(
149
+ request: Request,
150
+ db: Session = Depends(get_db),
151
+ oauth_token: str = Depends(oauth2_scheme_optional)
152
+ ):
153
+ """Get active user with flexible authentication"""
154
+ current_user = await get_current_user(request, db, oauth_token)
155
+
156
+ if not current_user.is_active:
157
+ raise HTTPException(status_code=400, detail="Inactive user")
158
+
159
+ return UserResponse.from_orm(current_user)
160
+
161
+ async def get_current_user_old(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
162
  log_info("Getting current user")
163
  credentials_exception = HTTPException(
164
  status_code=status.HTTP_401_UNAUTHORIZED,
 
182
  raise credentials_exception
183
  return user
184
 
185
+ async def get_current_active_user_old(current_user: User = Depends(get_current_user_old)):
186
  log_info("Getting current active user")
187
  try:
188
  if not current_user.is_active: