"""Authentication and security helpers for the Streamlit application.

Provides password validation, login rate limiting, JWT issue/verify, TOTP
verification, audit logging and session / metrics cleanup. Database access
goes through the connection stored in ``st.session_state.db``.
"""

import logging
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
from uuid import uuid4

import jwt
import pyotp
import streamlit as st
from prometheus_client import REGISTRY, Counter  # noqa: F401  (Counter kept for importers)

from utils.metrics import get_login_attempts_counter

# Initialize logging
logger = logging.getLogger(__name__)

# Use the global counter instance.
login_attempts_counter = get_login_attempts_counter()

# Share the global counter with the session instead of registering a second
# 'login_attempts_total' collector: prometheus_client raises ValueError when
# the same metric name is registered twice in one registry.
# NOTE(review): assumes get_login_attempts_counter() returns the
# prometheus_client Counter for login attempts - confirm in utils.metrics.
if 'login_attempts_counter' not in st.session_state:
    st.session_state.login_attempts_counter = login_attempts_counter

# JWT secret. Read from the environment so the real key never lives in source
# control; the placeholder is only a development fallback.
_DEFAULT_JWT_SECRET = "your-secret-key-here"
JWT_SECRET = os.environ.get("JWT_SECRET") or _DEFAULT_JWT_SECRET
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    logger.warning(
        "JWT_SECRET is not set; using the insecure placeholder secret. "
        "Set the JWT_SECRET environment variable in production."
    )


def clear_session_safely() -> bool:
    """Clear the Streamlit session while preserving critical components.

    The ``db`` and ``client_ip`` entries (when present) survive the clear;
    every other key is dropped.

    Returns:
        True when the session was cleared, False if clearing raised.
    """
    try:
        preserved_keys = {'db', 'client_ip'}
        preserved_values = {
            key: st.session_state[key]
            for key in preserved_keys
            if key in st.session_state
        }
        st.session_state.clear()
        for key, value in preserved_values.items():
            st.session_state[key] = value
        return True
    except Exception as e:
        logger.error("Session clearing failed: %s", e)
        return False


def validate_password(password: str) -> Tuple[bool, str]:
    """Validate password strength.

    Requires at least 12 characters, one uppercase letter, one lowercase
    letter, one digit and one of the special characters ``!@#$%^&*``.

    Returns:
        ``(True, "Password valid")`` on success, otherwise ``(False, reason)``
        for the first rule that fails.
    """
    if len(password) < 12:
        return False, "Password must be at least 12 characters"
    if not re.search(r"[A-Z]", password):
        return False, "Must contain uppercase letters"
    if not re.search(r"[a-z]", password):
        return False, "Must contain lowercase letters"
    if not re.search(r"\d", password):
        return False, "Must contain numbers"
    if not re.search(r"[!@#$%^&*]", password):
        return False, "Must contain special characters"
    return True, "Password valid"


def check_rate_limit(username: str, max_attempts: int = 3, window: int = 15) -> bool:
    """Record a login attempt and report whether *username* is rate limited.

    Counts the failed attempts stored for *username* within the last *window*
    minutes, then inserts a new row (with ``success = False``) for the current
    attempt.

    Args:
        username: Account name the attempt is made for.
        max_attempts: Number of prior failed attempts that triggers the limit.
        window: Look-back period in minutes.

    Returns:
        True when the prior failed-attempt count is at least *max_attempts*;
        False otherwise, and also False if the database access fails.
    """
    try:
        now = datetime.now()
        cursor = st.session_state.db.cursor()
        window_start = (now - timedelta(minutes=window)).isoformat()
        cursor.execute("""
            SELECT COUNT(*) FROM login_attempt_history
            WHERE username = ?
            AND attempt_time > ?
            AND success = FALSE
        """, (username, window_start))
        count = cursor.fetchone()[0]

        # NOTE(review): every check is stored as a failed attempt, before the
        # login outcome is known - confirm callers update the row on success.
        cursor.execute("""
            INSERT INTO login_attempt_history
            (id, username, attempt_time, ip_address, success)
            VALUES (?, ?, ?, ?, ?)
        """, (
            str(uuid4()),
            username,
            now.isoformat(),
            st.session_state.get('client_ip', 'unknown'),
            False,
        ))
        st.session_state.db.commit()
        return count >= max_attempts
    except Exception as e:
        # NOTE(review): this fails open (not rate limited) on any error.
        logger.error("Rate limit check failed: %s", e)
        return False


def generate_jwt_token(user_id: str, expiry_minutes: int = 30) -> str:
    """Generate a signed JWT for a user session.

    Args:
        user_id: Value stored in the ``user_id`` claim.
        expiry_minutes: Token lifetime in minutes.

    Returns:
        The HS256-signed token, carrying ``user_id``, ``exp``, ``iat`` and a
        unique ``jti`` claim.
    """
    # One timezone-aware timestamp keeps ``iat`` and ``exp`` consistent and
    # avoids the deprecated naive datetime.utcnow().
    issued_at = datetime.now(timezone.utc)
    payload = {
        'user_id': user_id,
        'exp': issued_at + timedelta(minutes=expiry_minutes),
        'iat': issued_at,
        'jti': str(uuid4()),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm='HS256')


def verify_jwt_token(token: str) -> Optional[str]:
    """Verify a JWT and return its ``user_id`` claim.

    Returns:
        The ``user_id`` from the payload, or None when the token is expired
        or otherwise invalid.
    """
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
        return payload.get('user_id')
    except jwt.ExpiredSignatureError:
        logger.warning("Token expired")
        return None
    except jwt.InvalidTokenError:
        logger.warning("Invalid token")
        return None


def verify_totp(secret: str, code: str) -> bool:
    """Verify a TOTP code against *secret*.

    Returns:
        True when the code verifies; False when it does not or when
        verification raises.
    """
    try:
        totp = pyotp.TOTP(secret)
        return totp.verify(code)
    except Exception as e:
        logger.error("TOTP verification failed: %s", e)
        return False


def log_security_event(user_id: str, event_type: str, details: str) -> None:
    """Log a security event to the ``audit_logs`` table.

    Best effort: a failure is logged and swallowed so that auditing problems
    do not interrupt the calling flow.

    Args:
        user_id: User the event relates to.
        event_type: Stored in the ``action`` column.
        details: Free-form description of the event.
    """
    try:
        cursor = st.session_state.db.cursor()
        cursor.execute("""
            INSERT INTO audit_logs (id, user_id, action, timestamp, details)
            VALUES (?, ?, ?, ?, ?)
        """, (
            str(uuid4()),
            user_id,
            event_type,
            datetime.now().isoformat(),
            details,
        ))
        st.session_state.db.commit()
    except Exception as e:
        logger.error("Failed to log security event: %s", e)


def clear_security_metrics() -> None:
    """Unregister the session's login-attempts counter on logout.

    Does nothing when the session holds no counter (for example after the
    session has been cleared) or when the counter is not registered.
    """
    counter = st.session_state.get('login_attempts_counter')
    if counter is None:
        return
    try:
        # NOTE(review): REGISTRY is process-wide, so this removes the metric
        # for every session, not only the one logging out - confirm intent.
        REGISTRY.unregister(counter)
    except KeyError:
        # Collector was never registered or has already been removed.
        pass