File size: 2,654 Bytes
02898ce
34e27fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02898ce
 
 
34e27fb
 
 
 
 
02898ce
 
 
 
 
 
 
34e27fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import hashlib
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Union

import bcrypt
from jose import JWTError, jwt

from ..config import settings


def _prepare_password(password: str) -> str:
    """
    Return a form of ``password`` that fits within bcrypt's 72-byte input limit.

    A password whose UTF-8 encoding is at most 72 bytes is returned unchanged.
    A longer one is replaced by the hexadecimal SHA-256 digest of its encoded
    bytes, so that every byte of the input still influences the result
    instead of being cut off.
    """
    encoded = password.encode('utf-8')
    if len(encoded) <= 72:
        return password
    # Over the limit: stand in the SHA-256 hex digest for the raw password.
    return hashlib.sha256(encoded).hexdigest()


def hash_password(password: str) -> str:
    """Return the bcrypt hash of ``password`` as a UTF-8 decoded string.

    The password is first passed through ``_prepare_password`` so that
    over-long inputs are SHA-256 pre-hashed, then hashed with a freshly
    generated salt using 12 bcrypt rounds.
    """
    candidate = _prepare_password(password).encode('utf-8')
    digest = bcrypt.hashpw(candidate, bcrypt.gensalt(rounds=12))
    return digest.decode('utf-8')


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored bcrypt hash.

    The plain password goes through the same ``_prepare_password`` step used
    when hashing, so over-long passwords are compared via their SHA-256
    pre-hash. Returns the result of ``bcrypt.checkpw``, or ``False`` when that
    call raises ``ValueError`` or ``TypeError`` (e.g. a malformed stored hash).
    """
    candidate = _prepare_password(plain_password).encode('utf-8')
    stored = hashed_password.encode('utf-8')
    try:
        matches = bcrypt.checkpw(candidate, stored)
    except (ValueError, TypeError):
        # The stored hash could not be processed; treat it as a failed check.
        return False
    return matches


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The mapping is shallow-copied, so
            the caller's dict is not modified.
        expires_delta: Lifetime of the token. When ``None``, the lifetime is
            ``settings.ACCESS_TOKEN_EXPIRE_DAYS`` days. An explicit zero
            ``timedelta`` is honoured and yields a token whose ``exp`` equals
            its ``iat``.

    Returns:
        The JWT produced by ``jwt.encode`` with ``settings.JWT_SECRET_KEY`` and
        ``settings.JWT_ALGORITHM``, carrying the given claims plus ``exp`` and
        ``iat``.
    """
    # Take a single timezone-aware timestamp so that ``exp - iat`` is exactly
    # the requested lifetime, and to avoid the deprecated ``datetime.utcnow()``.
    issued_at = datetime.now(timezone.utc)

    # Compare against None explicitly: ``timedelta(0)`` is falsy and would
    # otherwise silently fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.ACCESS_TOKEN_EXPIRE_DAYS)

    to_encode = data.copy()
    to_encode.update({"exp": issued_at + expires_delta, "iat": issued_at})

    return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)


def verify_token(token: str) -> Optional[dict]:
    """Decode ``token`` and return its payload.

    The token is decoded with ``settings.JWT_SECRET_KEY``, accepting only
    ``settings.JWT_ALGORITHM``. Returns ``None`` if decoding raises
    ``JWTError``.
    """
    try:
        claims = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        return None
    return claims


def verify_user_id_from_token(token: str) -> Optional[uuid.UUID]:
    """Return the user id stored in the token's ``sub`` claim as a ``UUID``.

    Returns ``None`` when ``verify_token`` yields no payload, when the ``sub``
    claim is missing or empty, or when its value is not a valid UUID string.
    """
    claims = verify_token(token)
    if not claims:
        return None

    subject = claims.get("sub")
    if not subject:
        return None

    try:
        return uuid.UUID(subject)
    except ValueError:
        # "sub" is present but is not parseable as a UUID.
        return None