Tahasaif3 commited on
Commit
baef2a6
·
verified ·
1 Parent(s): 1cf5def

Update src/utils/security.py

Browse files
Files changed (1) hide show
  1. src/utils/security.py +24 -90
src/utils/security.py CHANGED
@@ -2,122 +2,56 @@ from passlib.context import CryptContext
2
  from datetime import datetime, timedelta
3
  from typing import Optional, Union
4
  import uuid
5
- import logging
6
  from jose import JWTError, jwt
7
  from ..config import settings
8
 
9
- # Configure logger
10
- logger = logging.getLogger(__name__)
11
- logger.setLevel(logging.DEBUG)
12
-
13
- # Fallback values for JWT settings
14
- _FALLBACK_JWT_SECRET_KEY = "fallback_secret_key_for_development_only"
15
- _FALLBACK_JWT_ALGORITHM = "HS256"
16
- _FALLBACK_ACCESS_TOKEN_EXPIRE_DAYS = 7
17
-
18
  # Password hashing context
19
- # Handle bcrypt backend issues by specifying a fallback
20
- try:
21
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
22
- except Exception as e:
23
- logger.warning(f"Failed to initialize bcrypt context: {e}, using plaintext (NOT FOR PRODUCTION)")
24
- pwd_context = CryptContext(schemes=["plaintext"], deprecated="auto")
25
 
26
 
27
  def hash_password(password: str) -> str:
28
- """Hash a password using bcrypt or fallback method."""
29
- try:
30
- # Truncate password to 72 bytes to avoid bcrypt limitation
31
- if len(password.encode('utf-8')) > 72:
32
- logger.warning("Password exceeds 72 bytes, truncating")
33
- # Properly truncate UTF-8 bytes
34
- password_bytes = password.encode('utf-8')[:72]
35
- password = password_bytes.decode('utf-8', errors='ignore')
36
- return pwd_context.hash(password)
37
- except ValueError as ve:
38
- if "72 bytes" in str(ve):
39
- logger.error(f"Password too long even after truncation: {str(ve)}")
40
- # Force truncate to exactly 72 bytes and try again
41
- password = password.encode('utf-8')[:72].decode('utf-8', errors='ignore')
42
- return pwd_context.hash(password)
43
- else:
44
- logger.error(f"ValueError hashing password: {str(ve)}")
45
- raise
46
- except Exception as e:
47
- logger.error(f"Error hashing password: {str(e)}")
48
- raise
49
 
50
 
51
  def verify_password(plain_password: str, hashed_password: str) -> bool:
52
  """Verify a plain password against its hash."""
53
- try:
54
- # Truncate password to 72 bytes to match hashing behavior
55
- if len(plain_password.encode('utf-8')) > 72:
56
- logger.warning("Password exceeds 72 bytes during verification, truncating")
57
- plain_password = plain_password.encode('utf-8')[:72].decode('utf-8', errors='ignore')
58
- return pwd_context.verify(plain_password, hashed_password)
59
- except Exception as e:
60
- logger.error(f"Error verifying password: {str(e)}")
61
- return False
62
 
63
 
64
  def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
65
  """Create a JWT access token."""
66
- try:
67
- to_encode = data.copy()
68
 
69
- if expires_delta:
70
- expire = datetime.utcnow() + expires_delta
71
- else:
72
- # Default to 7 days if no expiration is provided
73
- expire = datetime.utcnow() + timedelta(days=getattr(settings, 'ACCESS_TOKEN_EXPIRE_DAYS', _FALLBACK_ACCESS_TOKEN_EXPIRE_DAYS))
74
 
75
- to_encode.update({"exp": expire, "iat": datetime.utcnow()})
76
-
77
- # Use fallback values if settings are not properly configured
78
- secret_key = getattr(settings, 'JWT_SECRET_KEY', _FALLBACK_JWT_SECRET_KEY)
79
- algorithm = getattr(settings, 'JWT_ALGORITHM', _FALLBACK_JWT_ALGORITHM)
80
-
81
- logger.debug(f"Creating token with settings: SECRET_KEY={secret_key}, ALGORITHM={algorithm}")
82
 
83
- encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
84
- return encoded_jwt
85
- except Exception as e:
86
- logger.error(f"Error creating access token: {str(e)}")
87
- raise
88
 
89
 
90
  def verify_token(token: str) -> Optional[dict]:
91
  """Verify a JWT token and return the payload if valid."""
92
  try:
93
- # Use fallback values if settings are not properly configured
94
- secret_key = getattr(settings, 'JWT_SECRET_KEY', _FALLBACK_JWT_SECRET_KEY)
95
- algorithm = getattr(settings, 'JWT_ALGORITHM', _FALLBACK_JWT_ALGORITHM)
96
-
97
- logger.debug(f"Verifying token with settings: SECRET_KEY={secret_key}, ALGORITHM={algorithm}")
98
- payload = jwt.decode(token, secret_key, algorithms=[algorithm])
99
  return payload
100
- except JWTError as e:
101
- logger.warning(f"JWT Error during token verification: {str(e)}")
102
- return None
103
- except Exception as e:
104
- logger.error(f"Unexpected error during token verification: {str(e)}")
105
  return None
106
 
107
 
108
  def verify_user_id_from_token(token: str) -> Optional[uuid.UUID]:
109
  """Extract user_id from JWT token."""
110
- try:
111
- payload = verify_token(token)
112
- if payload:
113
- user_id_str = payload.get("sub")
114
- if user_id_str:
115
- try:
116
- return uuid.UUID(user_id_str)
117
- except ValueError as e:
118
- logger.warning(f"Invalid UUID format in token payload: {user_id_str} - {str(e)}")
119
- return None
120
- return None
121
- except Exception as e:
122
- logger.error(f"Error extracting user ID from token: {str(e)}")
123
- return None
 
2
  from datetime import datetime, timedelta
3
  from typing import Optional, Union
4
  import uuid
 
5
  from jose import JWTError, jwt
6
  from ..config import settings
7
 
 
 
 
 
 
 
 
 
 
8
  # Password hashing context
9
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
 
 
 
 
10
 
11
 
12
  def hash_password(password: str) -> str:
13
+ """Hash a password using bcrypt."""
14
+ return pwd_context.hash(password)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def verify_password(plain_password: str, hashed_password: str) -> bool:
18
  """Verify a plain password against its hash."""
19
+ return pwd_context.verify(plain_password, hashed_password)
 
 
 
 
 
 
 
 
20
 
21
 
22
  def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
23
  """Create a JWT access token."""
24
+ to_encode = data.copy()
 
25
 
26
+ if expires_delta:
27
+ expire = datetime.utcnow() + expires_delta
28
+ else:
29
+ # Default to 7 days if no expiration is provided
30
+ expire = datetime.utcnow() + timedelta(days=settings.ACCESS_TOKEN_EXPIRE_DAYS)
31
 
32
+ to_encode.update({"exp": expire, "iat": datetime.utcnow()})
 
 
 
 
 
 
33
 
34
+ encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
35
+ return encoded_jwt
 
 
 
36
 
37
 
38
  def verify_token(token: str) -> Optional[dict]:
39
  """Verify a JWT token and return the payload if valid."""
40
  try:
41
+ payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
 
 
 
 
 
42
  return payload
43
+ except JWTError:
 
 
 
 
44
  return None
45
 
46
 
47
  def verify_user_id_from_token(token: str) -> Optional[uuid.UUID]:
48
  """Extract user_id from JWT token."""
49
+ payload = verify_token(token)
50
+ if payload:
51
+ user_id_str = payload.get("sub")
52
+ if user_id_str:
53
+ try:
54
+ return uuid.UUID(user_id_str)
55
+ except ValueError:
56
+ return None
57
+ return None