Spaces:
Build error
Build error
| import streamlit as st | |
| import re | |
| from datetime import datetime, timedelta | |
| import jwt | |
| from prometheus_client import Counter | |
| import pyotp | |
| from uuid import uuid4 | |
| from typing import Tuple, Optional | |
| import logging | |
| from utils.metrics import get_login_attempts_counter | |
| from prometheus_client import REGISTRY | |
# Module-level logger for this authentication module.
logger = logging.getLogger(__name__)

# Share the single login-attempts counter provided by utils.metrics.
# get_login_attempts_counter() presumably registers 'login_attempts_total' on the
# default REGISTRY already (TODO confirm in utils.metrics); prometheus_client
# raises ValueError when a second Counter with the same name is registered, so
# this module must reuse that instance instead of creating its own.
login_attempts_counter = get_login_attempts_counter()
if 'login_attempts_counter' not in st.session_state:
    # Expose the same counter through the session so per-session code
    # (e.g. metric clean-up on logout) always finds it.
    st.session_state.login_attempts_counter = login_attempts_counter
import os

# Development-only fallback. Tokens signed with it can be forged by anyone who
# has read this source, so a real secret must be supplied in production.
_JWT_SECRET_PLACEHOLDER = "your-secret-key-here"

# JWT signing secret (HS256). Read from the JWT_SECRET environment variable so the
# production secret stays out of source control; an unset or empty variable
# falls back to the placeholder above.
JWT_SECRET = os.environ.get("JWT_SECRET") or _JWT_SECRET_PLACEHOLDER
if JWT_SECRET == _JWT_SECRET_PLACEHOLDER:
    logging.getLogger(__name__).warning(
        "JWT_SECRET is not set; using the insecure development placeholder."
    )
def clear_session_safely():
    """Reset the Streamlit session state, keeping only ``db`` and ``client_ip``.

    Those two entries (when present) are copied aside, the whole session
    state is cleared, and the saved entries are written back.

    Returns:
        True when the reset completed; False if any step raised (the
        exception is logged and not re-raised).
    """
    survivors = frozenset(('db', 'client_ip'))
    try:
        kept = {}
        for name in survivors:
            if name in st.session_state:
                kept[name] = st.session_state[name]
        st.session_state.clear()
        for name, value in kept.items():
            st.session_state[name] = value
    except Exception as exc:
        logger.error(f"Session clearing failed: {exc!s}")
        return False
    return True
def clear_security_metrics():
    """Reset security metrics by delegating to ``utils.metrics.clear_metrics``.

    NOTE(review): ``clear_security_metrics`` is defined again later in this
    module. Python keeps only the last binding, so code importing this module
    never reaches this version. Consider deleting the duplicate definitions.
    """
    from utils.metrics import clear_metrics as reset_all_metrics

    reset_all_metrics()
def clear_security_metrics():
    """Unregister the session's login-attempts counter from the Prometheus REGISTRY.

    Best effort: does nothing when the session never stored a counter, or
    when that counter is no longer registered.

    NOTE(review): ``clear_security_metrics`` is defined again later in this
    module. Python keeps only the last binding, so code importing this module
    never reaches this version. Consider deleting the duplicate definitions.
    """
    # The module-level initialisation can leave this key unset, so read it
    # defensively instead of via attribute access (which would raise).
    counter = st.session_state.get('login_attempts_counter')
    if counter is None:
        return
    try:
        REGISTRY.unregister(counter)
    except KeyError:
        # prometheus_client raises KeyError for a collector that is not registered.
        pass
def validate_password(
    password: str,
    *,
    min_length: int = 12,
    special_chars: str = "!@#$%^&*",
) -> Tuple[bool, str]:
    """Check a password against the strength policy.

    The password must contain, in this order of checking:

    - at least ``min_length`` characters,
    - an uppercase ASCII letter,
    - a lowercase ASCII letter,
    - a digit,
    - a character from ``special_chars``.

    Only the first failing rule is reported.

    Args:
        password: Candidate password.
        min_length: Minimum accepted length (default 12).
        special_chars: Characters that satisfy the special-character rule
            (default ``"!@#$%^&*"``). An empty string disables that rule.

    Returns:
        ``(True, "Password valid")`` when every rule passes, otherwise
        ``(False, <message for the first failing rule>)``.
    """
    if len(password) < min_length:
        return False, f"Password must be at least {min_length} characters"
    if not re.search(r"[A-Z]", password):
        return False, "Must contain uppercase letters"
    if not re.search(r"[a-z]", password):
        return False, "Must contain lowercase letters"
    if not re.search(r"\d", password):
        return False, "Must contain numbers"
    # A plain membership test avoids building a regex character class, where
    # characters such as '^', ']' or '-' in special_chars would need escaping.
    if special_chars and not any(ch in special_chars for ch in password):
        return False, "Must contain special characters"
    return True, "Password valid"
def check_rate_limit(username: str, max_attempts: int = 3, window: int = 15) -> bool:
    """Record a login attempt for *username* and report whether it is rate-limited.

    First counts the failed attempts (``success = FALSE``) recorded for
    *username* in ``login_attempt_history`` during the last *window* minutes.
    It then inserts a row for the current attempt, marked ``success=False``,
    and commits.

    NOTE(review): every call is recorded as a failure, including attempts
    that later succeed. Presumably successful logins are recorded elsewhere;
    confirm with the caller.

    Args:
        username: Account the login is attempted for.
        max_attempts: Number of earlier failures in the window at which the
            user is considered rate-limited.
        window: Look-back window, in minutes.

    Returns:
        True if *username* already had at least *max_attempts* failures in the
        window, otherwise False. Also returns False when the database work
        fails; the limiter fails open, and the error is logged.
    """
    db = None
    cursor = None
    try:
        db = st.session_state.db
        cursor = db.cursor()
        # Read the clock once so the window start and the recorded attempt time
        # come from the same instant. Times are naive local ISO-8601 strings and
        # are compared as strings. That ordering holds while every stored row
        # uses this same isoformat() layout.
        now = datetime.now()
        window_start = (now - timedelta(minutes=window)).isoformat()
        cursor.execute(
            """
            SELECT COUNT(*) FROM login_attempt_history
            WHERE username = ? AND attempt_time > ? AND success = FALSE
            """,
            (username, window_start),
        )
        count = cursor.fetchone()[0]
        cursor.execute(
            """
            INSERT INTO login_attempt_history (id, username, attempt_time, ip_address, success)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                str(uuid4()),
                username,
                now.isoformat(),
                st.session_state.get('client_ip', 'unknown'),
                False,
            ),
        )
        db.commit()
        return count >= max_attempts
    except Exception as e:
        logger.error(f"Rate limit check failed: {str(e)}")
        # Discard the uncommitted INSERT so it cannot be committed by some
        # unrelated later commit on the same shared connection.
        if db is not None:
            try:
                db.rollback()
            except Exception as rollback_error:
                logger.warning(f"Rate limit rollback failed: {rollback_error}")
        return False
    finally:
        if cursor is not None:
            cursor.close()
def generate_jwt_token(user_id: str, expiry_minutes: int = 30) -> str:
    """Create an HS256-signed JWT for *user_id*, signed with ``JWT_SECRET``.

    Claims:
        user_id: The given user id.
        iat: Issue time (now, UTC).
        exp: ``iat + expiry_minutes`` minutes.
        jti: A random UUID4 string identifying this token.

    Args:
        user_id: Identifier stored in the ``user_id`` claim.
        expiry_minutes: Token lifetime in minutes (default 30).

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    from datetime import timezone

    # One timezone-aware timestamp: datetime.utcnow() is deprecated (3.12+),
    # and calling it twice made iat and exp come from different instants.
    issued_at = datetime.now(timezone.utc)
    payload = {
        'user_id': user_id,
        'exp': issued_at + timedelta(minutes=expiry_minutes),
        'iat': issued_at,
        'jti': str(uuid4()),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm='HS256')
def verify_jwt_token(token: str) -> Optional[str]:
    """Verify an HS256 JWT signed with ``JWT_SECRET`` and return its ``user_id``.

    The token must carry an ``exp`` claim. Tokens without one are rejected
    rather than accepted forever. ``generate_jwt_token`` always sets ``exp``.

    Args:
        token: Encoded JWT.

    Returns:
        The ``user_id`` claim (None if the claim is absent), or None when the
        token is expired or otherwise invalid; a warning is logged in that case.
    """
    try:
        payload = jwt.decode(
            token,
            JWT_SECRET,
            algorithms=['HS256'],
            options={'require': ['exp']},
        )
    except jwt.ExpiredSignatureError:
        logger.warning("Token expired")
        return None
    except jwt.InvalidTokenError:
        # Also covers a missing required claim (MissingRequiredClaimError).
        logger.warning("Invalid token")
        return None
    return payload.get('user_id')
def verify_totp(secret: str, code: str, valid_window: int = 0) -> bool:
    """Check a time-based one-time password against *secret* using pyotp.

    Args:
        secret: TOTP shared secret, passed to ``pyotp.TOTP``. It is presumably
            base32, as pyotp expects; confirm against enrolment code.
        code: Code entered by the user. All whitespace is removed first, so
            "123 456" is treated as "123456".
        valid_window: Number of time steps before and after the current one
            that are also accepted (pyotp's ``valid_window``). The default 0
            accepts only the current step. Raise it to tolerate clock drift.

    Returns:
        True if the code matches; False otherwise or if verification raised
        (the error is logged).
    """
    try:
        normalized = "".join(str(code).split())
        return pyotp.TOTP(secret).verify(normalized, valid_window=valid_window)
    except Exception as e:
        logger.error(f"TOTP verification failed: {str(e)}")
        return False
def log_security_event(user_id: str, event_type: str, details: str):
    """Append one row to the ``audit_logs`` table and commit.

    The row holds a fresh UUID4 id, *user_id*, *event_type* (stored in the
    ``action`` column), the current naive local time as an ISO-8601 string,
    and *details*. Any exception is logged and not re-raised.
    """
    insert_sql = """
        INSERT INTO audit_logs (id, user_id, action, timestamp, details)
        VALUES (?, ?, ?, ?, ?)
        """
    try:
        connection = st.session_state.db
        cur = connection.cursor()
        row = (
            str(uuid4()),
            user_id,
            event_type,
            datetime.now().isoformat(),
            details,
        )
        cur.execute(insert_sql, row)
        connection.commit()
    except Exception as exc:
        logger.error(f"Failed to log security event: {exc!s}")
def clear_security_metrics():
    """Clear security metrics on logout.

    Delegates to ``utils.metrics.clear_metrics``. The shared login-attempts
    counter used by this module also comes from ``utils.metrics``, so
    clear_metrics presumably resets it; TODO confirm its exact scope.
    """
    # prometheus_client's CollectorRegistry has no clear() method, so the
    # reset goes through the metrics module instead of the registry.
    from utils.metrics import clear_metrics

    clear_metrics()