Spaces:
Sleeping
Sleeping
| """ | |
| User Authentication Module for RAG Personal Diary Chatbot | |
| Developed by huytrao | |
| This module handles user registration, login, session management | |
| and user-specific data isolation. | |
| """ | |
| import sqlite3 | |
| import hashlib | |
| import secrets | |
| import streamlit as st | |
| from datetime import datetime, timedelta | |
| from typing import Optional, Dict, Any | |
| import os | |
| import re | |
| class UserAuthManager: | |
| """ | |
| Handles user authentication and session management for the diary app. | |
| """ | |
| def __init__(self, db_path: str = "auth.db"): | |
| """ | |
| Initialize the authentication manager. | |
| Args: | |
| db_path: Path to the authentication database | |
| """ | |
| self.db_path = db_path | |
| self.session_timeout = timedelta(hours=24) # 24 hour session timeout | |
| self._init_auth_database() | |
| def _init_auth_database(self): | |
| """Initialize the authentication database tables.""" | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| # Users table | |
| cursor.execute(''' | |
| CREATE TABLE IF NOT EXISTS users ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| username TEXT UNIQUE NOT NULL, | |
| email TEXT UNIQUE NOT NULL, | |
| password_hash TEXT NOT NULL, | |
| salt TEXT NOT NULL, | |
| created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
| last_login TIMESTAMP, | |
| is_active BOOLEAN DEFAULT 1, | |
| profile_data TEXT DEFAULT '{}' | |
| ) | |
| ''') | |
| # Sessions table for user session management | |
| cursor.execute(''' | |
| CREATE TABLE IF NOT EXISTS user_sessions ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| user_id INTEGER NOT NULL, | |
| session_token TEXT UNIQUE NOT NULL, | |
| created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
| expires_at TIMESTAMP NOT NULL, | |
| is_active BOOLEAN DEFAULT 1, | |
| FOREIGN KEY (user_id) REFERENCES users (id) | |
| ) | |
| ''') | |
| # Create indexes for better performance | |
| cursor.execute('CREATE INDEX IF NOT EXISTS idx_sessions_token ON user_sessions(session_token)') | |
| cursor.execute('CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON user_sessions(user_id)') | |
| cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)') | |
| cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)') | |
| conn.commit() | |
| conn.close() | |
| def _hash_password(self, password: str, salt: str = None) -> tuple: | |
| """ | |
| Hash a password with salt using PBKDF2. | |
| Args: | |
| password: The password to hash | |
| salt: Optional salt (will generate new one if not provided) | |
| Returns: | |
| Tuple of (hashed_password, salt) | |
| """ | |
| if salt is None: | |
| salt = secrets.token_hex(32) | |
| # Use PBKDF2 with SHA-256 | |
| password_hash = hashlib.pbkdf2_hmac( | |
| 'sha256', | |
| password.encode('utf-8'), | |
| salt.encode('utf-8'), | |
| 100000 # 100,000 iterations | |
| ) | |
| return password_hash.hex(), salt | |
| def _generate_session_token(self) -> str: | |
| """Generate a secure session token.""" | |
| return secrets.token_urlsafe(32) | |
| def register_user(self, username: str, email: str, password: str) -> Dict[str, Any]: | |
| """ | |
| Register a new user. | |
| Args: | |
| username: Username (3-20 characters, alphanumeric + underscore) | |
| email: Email address | |
| password: Password (min 8 characters) | |
| Returns: | |
| Dictionary with success status and message | |
| """ | |
| # Validation | |
| if not self._validate_username(username): | |
| return {"success": False, "message": "Username must be 3-20 characters, alphanumeric and underscore only"} | |
| if not self._validate_email(email): | |
| return {"success": False, "message": "Invalid email format"} | |
| if not self._validate_password(password): | |
| return {"success": False, "message": "Password must be at least 8 characters"} | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| # Check if username or email already exists | |
| cursor.execute('SELECT id FROM users WHERE username = ? OR email = ?', (username, email)) | |
| if cursor.fetchone(): | |
| return {"success": False, "message": "Username or email already exists"} | |
| # Hash password | |
| password_hash, salt = self._hash_password(password) | |
| # Insert new user | |
| cursor.execute(''' | |
| INSERT INTO users (username, email, password_hash, salt) | |
| VALUES (?, ?, ?, ?) | |
| ''', (username, email, password_hash, salt)) | |
| user_id = cursor.lastrowid | |
| conn.commit() | |
| conn.close() | |
| return { | |
| "success": True, | |
| "message": "User registered successfully", | |
| "user_id": user_id | |
| } | |
| except sqlite3.Error as e: | |
| return {"success": False, "message": f"Database error: {str(e)}"} | |
| def login_user(self, username: str, password: str) -> Dict[str, Any]: | |
| """ | |
| Authenticate user and create session. | |
| Args: | |
| username: Username or email | |
| password: Password | |
| Returns: | |
| Dictionary with success status, message, and session data | |
| """ | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| # Get user by username or email | |
| cursor.execute(''' | |
| SELECT id, username, email, password_hash, salt, is_active | |
| FROM users | |
| WHERE (username = ? OR email = ?) AND is_active = 1 | |
| ''', (username, username)) | |
| user = cursor.fetchone() | |
| if not user: | |
| return {"success": False, "message": "Invalid credentials"} | |
| user_id, user_username, user_email, stored_hash, salt, is_active = user | |
| # Verify password | |
| password_hash, _ = self._hash_password(password, salt) | |
| if password_hash != stored_hash: | |
| return {"success": False, "message": "Invalid credentials"} | |
| # Create session | |
| session_token = self._generate_session_token() | |
| expires_at = datetime.now() + self.session_timeout | |
| # Clean up old sessions for this user | |
| cursor.execute('UPDATE user_sessions SET is_active = 0 WHERE user_id = ?', (user_id,)) | |
| # Insert new session | |
| cursor.execute(''' | |
| INSERT INTO user_sessions (user_id, session_token, expires_at) | |
| VALUES (?, ?, ?) | |
| ''', (user_id, session_token, expires_at)) | |
| # Update last login | |
| cursor.execute('UPDATE users SET last_login = CURRENT_TIMESTAMP WHERE id = ?', (user_id,)) | |
| conn.commit() | |
| conn.close() | |
| return { | |
| "success": True, | |
| "message": "Login successful", | |
| "session_token": session_token, | |
| "user": { | |
| "id": user_id, | |
| "username": user_username, | |
| "email": user_email | |
| } | |
| } | |
| except sqlite3.Error as e: | |
| return {"success": False, "message": f"Database error: {str(e)}"} | |
| def validate_session(self, session_token: str) -> Optional[Dict[str, Any]]: | |
| """ | |
| Validate a session token and return user data. | |
| Args: | |
| session_token: The session token to validate | |
| Returns: | |
| User data if valid session, None otherwise | |
| """ | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute(''' | |
| SELECT u.id, u.username, u.email, s.expires_at | |
| FROM user_sessions s | |
| JOIN users u ON s.user_id = u.id | |
| WHERE s.session_token = ? AND s.is_active = 1 AND u.is_active = 1 | |
| ''', (session_token,)) | |
| result = cursor.fetchone() | |
| if not result: | |
| return None | |
| user_id, username, email, expires_at = result | |
| # Check if session has expired | |
| expires_datetime = datetime.fromisoformat(expires_at) | |
| if datetime.now() > expires_datetime: | |
| # Deactivate expired session | |
| cursor.execute('UPDATE user_sessions SET is_active = 0 WHERE session_token = ?', (session_token,)) | |
| conn.commit() | |
| conn.close() | |
| return None | |
| conn.close() | |
| return { | |
| "id": user_id, | |
| "username": username, | |
| "email": email | |
| } | |
| except sqlite3.Error: | |
| return None | |
| def logout_user(self, session_token: str) -> bool: | |
| """ | |
| Logout user by deactivating session. | |
| Args: | |
| session_token: The session token to logout | |
| Returns: | |
| True if successful, False otherwise | |
| """ | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute('UPDATE user_sessions SET is_active = 0 WHERE session_token = ?', (session_token,)) | |
| conn.commit() | |
| conn.close() | |
| return True | |
| except sqlite3.Error: | |
| return False | |
| def change_password(self, user_id: int, current_password: str, new_password: str) -> Dict[str, Any]: | |
| """ | |
| Change user password. | |
| Args: | |
| user_id: User ID | |
| current_password: Current password | |
| new_password: New password | |
| Returns: | |
| Dictionary with success status and message | |
| """ | |
| if not self._validate_password(new_password): | |
| return {"success": False, "message": "New password must be at least 8 characters"} | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| # Get current password hash | |
| cursor.execute('SELECT password_hash, salt FROM users WHERE id = ?', (user_id,)) | |
| result = cursor.fetchone() | |
| if not result: | |
| return {"success": False, "message": "User not found"} | |
| stored_hash, salt = result | |
| # Verify current password | |
| current_hash, _ = self._hash_password(current_password, salt) | |
| if current_hash != stored_hash: | |
| return {"success": False, "message": "Current password is incorrect"} | |
| # Hash new password | |
| new_hash, new_salt = self._hash_password(new_password) | |
| # Update password | |
| cursor.execute('UPDATE users SET password_hash = ?, salt = ? WHERE id = ?', | |
| (new_hash, new_salt, user_id)) | |
| # Deactivate all sessions to force re-login | |
| cursor.execute('UPDATE user_sessions SET is_active = 0 WHERE user_id = ?', (user_id,)) | |
| conn.commit() | |
| conn.close() | |
| return {"success": True, "message": "Password changed successfully"} | |
| except sqlite3.Error as e: | |
| return {"success": False, "message": f"Database error: {str(e)}"} | |
| def get_user_profile(self, user_id: int) -> Optional[Dict[str, Any]]: | |
| """ | |
| Get user profile data. | |
| Args: | |
| user_id: User ID | |
| Returns: | |
| User profile data or None | |
| """ | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute(''' | |
| SELECT username, email, created_at, last_login, profile_data | |
| FROM users | |
| WHERE id = ? AND is_active = 1 | |
| ''', (user_id,)) | |
| result = cursor.fetchone() | |
| if not result: | |
| return None | |
| username, email, created_at, last_login, profile_data = result | |
| conn.close() | |
| return { | |
| "id": user_id, | |
| "username": username, | |
| "email": email, | |
| "created_at": created_at, | |
| "last_login": last_login, | |
| "profile_data": profile_data | |
| } | |
| except sqlite3.Error: | |
| return None | |
| def _validate_username(self, username: str) -> bool: | |
| """Validate username format.""" | |
| if not username or len(username) < 3 or len(username) > 20: | |
| return False | |
| return re.match(r'^[a-zA-Z0-9_]+$', username) is not None | |
| def _validate_email(self, email: str) -> bool: | |
| """Validate email format.""" | |
| pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' | |
| return re.match(pattern, email) is not None | |
| def _validate_password(self, password: str) -> bool: | |
| """Validate password format.""" | |
| return len(password) >= 8 | |
| def cleanup_expired_sessions(self): | |
| """Clean up expired sessions from the database.""" | |
| try: | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute('UPDATE user_sessions SET is_active = 0 WHERE expires_at < ?', (datetime.now(),)) | |
| conn.commit() | |
| conn.close() | |
| except sqlite3.Error: | |
| pass | |