# Author: shaheerawan3
# Commit: "Update utils/security.py" (3fcb5a8, verified)
import logging
import os
import re
from datetime import datetime, timedelta
from typing import Tuple, Optional
from uuid import uuid4

import jwt
import pyotp
import streamlit as st
from prometheus_client import Counter
from prometheus_client import REGISTRY

from utils.metrics import get_login_attempts_counter
# Initialize logging
logger = logging.getLogger(__name__)

# Use the global counter instance provided by utils.metrics.
login_attempts_counter = get_login_attempts_counter()

# Make sure every Streamlit session carries a reference to a login-attempts
# counter.  The module-level ``login_attempts_counter`` is deliberately left
# untouched here: it must stay a counter object, not a sample value.
if 'login_attempts_counter' not in st.session_state:
    # ``get_sample_value`` returns None when no such sample exists.  Compare
    # against None explicitly: a registered counter that is still at 0.0 is
    # falsy, and treating it as "missing" would trigger a duplicate
    # registration.
    if REGISTRY.get_sample_value('login_attempts_total') is None:
        try:
            st.session_state.login_attempts_counter = Counter(
                'login_attempts_total', 'Total login attempts', registry=REGISTRY
            )
        except ValueError:
            # The registry rejected the registration (the name is already
            # taken).  Retrying the same registration can only fail again, so
            # fall back to the shared instance.
            # NOTE(review): assumes get_login_attempts_counter() returns the
            # counter registered under this name -- confirm in utils.metrics.
            st.session_state.login_attempts_counter = login_attempts_counter
    else:
        # The metric already exists in the registry; reuse the shared instance
        # instead of leaving the session without a counter.
        st.session_state.login_attempts_counter = login_attempts_counter
# JWT secret - read from the JWT_SECRET environment variable so the real
# signing key never has to live in source control.  The literal below is only
# a development fallback (used when the variable is unset or empty) and must
# be replaced with a secure secret in production.
JWT_SECRET = os.environ.get("JWT_SECRET") or "your-secret-key-here"
def clear_session_safely():
    """Reset the Streamlit session while keeping the 'db' and 'client_ip' entries.

    The two entries are copied out (when present), the whole session state is
    cleared, and the copies are written back.

    Returns:
        True if the session was cleared and the kept entries restored,
        False if any step raised; the error is logged in that case.
    """
    keep = ('db', 'client_ip')
    try:
        # Snapshot only the entries that actually exist in this session.
        snapshot = {}
        for name in keep:
            if name in st.session_state:
                snapshot[name] = st.session_state[name]

        st.session_state.clear()

        # Put the preserved entries back into the now-empty session.
        for name in snapshot:
            st.session_state[name] = snapshot[name]
    except Exception as e:
        logger.error(f"Session clearing failed: {str(e)}")
        return False
    else:
        return True
def clear_security_metrics():
    """Clear security metrics by delegating to ``utils.metrics.clear_metrics``.

    NOTE(review): this module defines ``clear_security_metrics`` more than
    once; Python keeps only the last definition, so this version is shadowed
    at import time.
    """
    # Resolved at call time rather than at module import.
    from utils.metrics import clear_metrics as _reset_metrics

    _reset_metrics()
def clear_security_metrics():
    """Unregister the session's login-attempts counter from ``REGISTRY``.

    A ``KeyError`` raised while doing so is ignored (presumably the registry
    signalling that the collector is not registered -- confirm against
    prometheus_client).

    NOTE(review): this module defines ``clear_security_metrics`` more than
    once; Python keeps only the last definition, so this version is shadowed
    at import time.
    """
    try:
        session_counter = st.session_state.login_attempts_counter
        REGISTRY.unregister(session_counter)
    except KeyError:
        # Nothing to undo; treat as already cleared.
        return
def validate_password(password: str) -> Tuple[bool, str]:
    """Check a password against the strength policy.

    The policy is: at least 12 characters, and at least one uppercase letter,
    one lowercase letter, one digit and one of the characters ``!@#$%^&*``.
    Rules are checked in that order and the first violated rule is reported.

    Args:
        password: The candidate password.

    Returns:
        ``(True, "Password valid")`` when every rule passes, otherwise
        ``(False, <message for the first failed rule>)``.
    """
    if len(password) < 12:
        return False, "Password must be at least 12 characters"

    # (pattern that must match somewhere, message reported when it does not)
    required_classes = (
        (r"[A-Z]", "Must contain uppercase letters"),
        (r"[a-z]", "Must contain lowercase letters"),
        (r"\d", "Must contain numbers"),
        (r"[!@#$%^&*]", "Must contain special characters"),
    )
    for pattern, complaint in required_classes:
        if re.search(pattern, password) is None:
            return False, complaint

    return True, "Password valid"
def check_rate_limit(username: str, max_attempts: int = 3, window: int = 15) -> bool:
    """Tell whether ``username`` has reached the failed-login limit.

    Counts the rows in ``login_attempt_history`` for this user with
    ``success = FALSE`` that are newer than ``window`` minutes, then inserts
    a new row (with ``success`` set to False) for the current call and
    commits.

    Args:
        username: Account name being checked.
        max_attempts: Number of recent failed attempts at which the user is
            considered rate limited.
        window: Look-back period in minutes.

    Returns:
        True when the count taken *before* this call's row was inserted is at
        least ``max_attempts``; False otherwise.  Also False when anything
        raises (the error is logged).

    NOTE(review): on error this reports "not limited", and every call records
    a failed attempt regardless of the eventual login outcome -- confirm both
    are intended.
    """
    try:
        db = st.session_state.db
        cur = db.cursor()

        # Oldest timestamp that still falls inside the look-back window.
        cutoff = (datetime.now() - timedelta(minutes=window)).isoformat()
        cur.execute("""
            SELECT COUNT(*) FROM login_attempt_history
            WHERE username = ? AND attempt_time > ? AND success = FALSE
        """, (username, cutoff))
        recent_failures = cur.fetchone()[0]

        # Record this attempt after counting, so it is not part of the count.
        attempt_row = (
            str(uuid4()),
            username,
            datetime.now().isoformat(),
            st.session_state.get('client_ip', 'unknown'),
            False,
        )
        cur.execute("""
            INSERT INTO login_attempt_history (id, username, attempt_time, ip_address, success)
            VALUES (?, ?, ?, ?, ?)
        """, attempt_row)
        db.commit()

        limited = recent_failures >= max_attempts
    except Exception as e:
        logger.error(f"Rate limit check failed: {str(e)}")
        return False
    else:
        return limited
def generate_jwt_token(user_id: str, expiry_minutes: int = 30) -> str:
    """Generate a signed JWT for a user session.

    The payload carries ``user_id``, an expiry (``exp``), the issue time
    (``iat``) and a random unique token id (``jti``).  It is signed with
    ``JWT_SECRET`` using HS256.

    Args:
        user_id: Identifier stored in the ``user_id`` claim.
        expiry_minutes: Token lifetime in minutes, counted from issue time.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    # Local import: the module-level datetime import does not expose timezone.
    from datetime import timezone

    # One timezone-aware timestamp for both claims: ``datetime.utcnow()`` is
    # deprecated and naive, and calling it twice lets ``iat`` and ``exp``
    # drift apart by the time between the two calls.
    issued_at = datetime.now(timezone.utc)
    payload = {
        'user_id': user_id,
        'exp': issued_at + timedelta(minutes=expiry_minutes),
        'iat': issued_at,
        'jti': str(uuid4()),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm='HS256')
def verify_jwt_token(token: str) -> Optional[str]:
    """Decode a JWT and return the ``user_id`` claim it carries.

    The token is decoded with ``JWT_SECRET`` and only the HS256 algorithm is
    accepted.

    Args:
        token: Encoded JWT to check.

    Returns:
        The value of the ``user_id`` claim (None if the claim is absent), or
        None when the token is expired or otherwise invalid; a warning is
        logged in the latter two cases.
    """
    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
    except jwt.ExpiredSignatureError:
        failure = "Token expired"
    except jwt.InvalidTokenError:
        failure = "Invalid token"
    else:
        return claims.get('user_id')

    # Only reached when decoding failed.
    logger.warning(failure)
    return None
def verify_totp(secret: str, code: str) -> bool:
    """Check a one-time code against a TOTP secret.

    Args:
        secret: Shared secret used to build the ``pyotp.TOTP`` verifier.
        code: Code supplied by the user.

    Returns:
        The result of ``pyotp.TOTP(secret).verify(code)``, or False if
        building the verifier or verifying raised (the error is logged).
    """
    try:
        is_valid = pyotp.TOTP(secret).verify(code)
    except Exception as e:
        logger.error(f"TOTP verification failed: {str(e)}")
        return False
    return is_valid
def log_security_event(user_id: str, event_type: str, details: str):
    """Write one security event into the ``audit_logs`` table.

    A row with a fresh UUID, the given user, ``event_type`` (stored in the
    ``action`` column), the current local timestamp in ISO format and
    ``details`` is inserted through the session's database connection and
    committed.

    Args:
        user_id: User the event relates to.
        event_type: Short name of the event.
        details: Free-form description of the event.

    Returns:
        None.  Failures are logged and never propagated to the caller.
    """
    try:
        db = st.session_state.db
        cur = db.cursor()
        audit_row = (
            str(uuid4()),
            user_id,
            event_type,
            datetime.now().isoformat(),
            details,
        )
        cur.execute("""
            INSERT INTO audit_logs (id, user_id, action, timestamp, details)
            VALUES (?, ?, ?, ?, ?)
        """, audit_row)
        db.commit()
    except Exception as e:
        logger.error(f"Failed to log security event: {str(e)}")
def clear_security_metrics():
    """Clear security metrics on logout.

    Unregisters the login-attempts counter stored in the Streamlit session
    (``st.session_state['login_attempts_counter']``) from the Prometheus
    ``REGISTRY``.  Does nothing when the session holds no counter, and
    tolerates a counter the registry no longer knows about.

    The previous body called ``registry.clear()``, but no name ``registry``
    exists in this module (only ``REGISTRY`` is imported), so every call
    raised NameError.
    """
    counter = st.session_state.get('login_attempts_counter')
    if counter is None:
        # No counter was ever attached to this session; nothing to clear.
        return

    try:
        REGISTRY.unregister(counter)
    except KeyError:
        # Collector is not (or no longer) registered; the metric is already
        # gone, which is the state we want.
        pass