auth / adaptiveauth /risk /monitor.py
Prasannata's picture
first commit
7d369c8
"""
AdaptiveAuth Session Monitor
Continuous session verification and monitoring.
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from sqlalchemy import and_
from ..models import (
User, UserSession, RiskEvent, AnomalyPattern, LoginAttempt,
SessionStatus, RiskLevel
)
from ..config import get_settings
from .engine import RiskEngine, RiskAssessment
class SessionMonitor:
    """
    Continuous session monitoring for adaptive authentication.

    Tracks session activity and triggers step-up auth when needed.
    Methods that modify sessions commit on the database session that was
    passed to the constructor.
    """

    def __init__(self, db: Session):
        """Store the database session and load the application settings."""
        self.db = db
        self.settings = get_settings()

    def create_session(
        self,
        user: User,
        context: Dict[str, Any],
        risk_assessment: RiskAssessment,
        token: str,
        expires_at: datetime
    ) -> UserSession:
        """
        Create a new user session after successful login.

        Args:
            user: Owner of the new session.
            context: Request context. The keys read here are 'ip_address',
                'user_agent', 'device_fingerprint', 'country' and 'city';
                a missing IP or user agent is stored as an empty string.
            risk_assessment: Assessment whose risk_score, risk_level and
                security_level seed the session's risk state.
            token: Session token to persist.
            expires_at: Moment at which the session expires.

        Returns:
            The persisted (committed and refreshed) UserSession.
        """
        # Make room first so the new session does not exceed the limit.
        self._enforce_session_limit(user)

        # One timestamp so last_activity and created_at agree exactly.
        now = datetime.utcnow()
        session = UserSession(
            user_id=user.id,
            session_token=token,
            ip_address=context.get('ip_address', ''),
            user_agent=context.get('user_agent', ''),
            device_fingerprint=context.get('device_fingerprint'),
            country=context.get('country'),
            city=context.get('city'),
            current_risk_score=risk_assessment.risk_score,
            current_risk_level=risk_assessment.risk_level.value,
            status=SessionStatus.ACTIVE.value,
            # Security levels 0 and 1 need no step-up, so mark it as done.
            step_up_completed=risk_assessment.security_level <= 1,
            last_activity=now,
            created_at=now,
            expires_at=expires_at
        )
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def _enforce_session_limit(self, user: User):
        """
        Revoke oldest sessions if limit exceeded.

        Enough of the user's oldest active sessions are revoked that one
        more session can be added without exceeding
        settings.MAX_CONCURRENT_SESSIONS.
        """
        active_sessions = self.db.query(UserSession).filter(
            UserSession.user_id == user.id,
            UserSession.status == SessionStatus.ACTIVE.value
        ).order_by(UserSession.created_at.asc()).all()

        max_sessions = self.settings.MAX_CONCURRENT_SESSIONS
        if len(active_sessions) >= max_sessions:
            # "+ 1" leaves a free slot for the session about to be created.
            surplus = len(active_sessions) - max_sessions + 1
            for session in active_sessions[:surplus]:
                session.status = SessionStatus.REVOKED.value
            self.db.commit()

    def verify_session(
        self,
        session: UserSession,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Verify ongoing session and check for anomalies.

        Called periodically during active sessions.

        Args:
            session: The session being verified.
            context: Current request context (see _check_context_change
                for the keys that are compared).

        Returns:
            A dict with the keys 'valid' (bool), 'step_up_required' (bool)
            and 'reason' (str or None).
        """
        result = {
            'valid': True,
            'step_up_required': False,
            'reason': None
        }

        # Check if session is still active
        if session.status != SessionStatus.ACTIVE.value:
            result['valid'] = False
            result['reason'] = 'Session is no longer active'
            return result

        # Check expiration; persist the EXPIRED status on first detection
        if session.expires_at < datetime.utcnow():
            session.status = SessionStatus.EXPIRED.value
            self.db.commit()
            result['valid'] = False
            result['reason'] = 'Session has expired'
            return result

        # Check for context changes (IP, device) and re-evaluate risk
        if self._check_context_change(session, context):
            user = self.db.query(User).filter(User.id == session.user_id).first()
            if user:
                assessment = RiskEngine(self.db).evaluate_risk(user, context)

                # Update session risk
                session.current_risk_score = assessment.risk_score
                session.current_risk_level = assessment.risk_level.value

                if assessment.security_level >= 3:
                    session.status = SessionStatus.SUSPICIOUS.value
                    result['step_up_required'] = True
                    result['reason'] = 'Session context changed significantly'
                self.db.commit()

        # Update last activity. A NULL counter is treated as zero so the
        # increment cannot raise TypeError.
        session.last_activity = datetime.utcnow()
        session.activity_count = (session.activity_count or 0) + 1
        self.db.commit()
        return result

    def _check_context_change(
        self,
        session: UserSession,
        context: Dict[str, Any]
    ) -> bool:
        """
        Check if session context has changed significantly.

        Returns True when the IP address differs from the stored one, or
        when both the stored and the current device fingerprint are present
        and differ.
        """
        # create_session stores a missing IP as '', so a missing IP is
        # normalised to '' on both sides; otherwise a session created
        # without an IP would look "changed" on every verification.
        current_ip = context.get('ip_address') or ''
        stored_ip = session.ip_address or ''
        if current_ip != stored_ip:
            return True

        # Device fingerprint change (only comparable when both are known)
        current_fingerprint = context.get('device_fingerprint')
        if (session.device_fingerprint and
                current_fingerprint and
                current_fingerprint != session.device_fingerprint):
            return True

        return False

    def get_user_sessions(
        self,
        user: User,
        include_expired: bool = False
    ) -> List[UserSession]:
        """
        Get all sessions for a user, newest first.

        Args:
            user: Owner of the sessions.
            include_expired: When False (default) only ACTIVE sessions are
                returned; when True no status filter is applied, so sessions
                in every status are returned.
        """
        query = self.db.query(UserSession).filter(
            UserSession.user_id == user.id
        )
        if not include_expired:
            query = query.filter(
                UserSession.status == SessionStatus.ACTIVE.value
            )
        return query.order_by(UserSession.created_at.desc()).all()

    def revoke_session(self, session_id: int, reason: str = "User requested"):
        """
        Revoke a specific session.

        Does nothing when no session with this id exists.

        Args:
            session_id: Primary key of the session to revoke.
            reason: Kept in the signature for callers; this method does not
                currently store or log it.
        """
        session = self.db.query(UserSession).filter(
            UserSession.id == session_id
        ).first()
        if session:
            session.status = SessionStatus.REVOKED.value
            self.db.commit()

    def revoke_all_sessions(self, user: User, except_session_id: Optional[int] = None):
        """
        Revoke all active sessions for a user.

        Args:
            user: Owner of the sessions.
            except_session_id: Id of a session to keep active (typically the
                caller's own), or None to revoke every active session.
        """
        query = self.db.query(UserSession).filter(
            UserSession.user_id == user.id,
            UserSession.status == SessionStatus.ACTIVE.value
        )
        # Explicit None check so that an id of 0 is still honoured.
        if except_session_id is not None:
            query = query.filter(UserSession.id != except_session_id)

        for session in query.all():
            session.status = SessionStatus.REVOKED.value
        self.db.commit()

    def mark_session_suspicious(self, session: UserSession, reason: str):
        """
        Mark a session as suspicious.

        Also records a 'session_suspicious' RiskEvent that copies the
        session's current risk data and stores `reason` in risk_factors.
        """
        session.status = SessionStatus.SUSPICIOUS.value

        # Log risk event
        event = RiskEvent(
            user_id=session.user_id,
            event_type='session_suspicious',
            risk_score=session.current_risk_score,
            risk_level=session.current_risk_level,
            ip_address=session.ip_address,
            user_agent=session.user_agent,
            risk_factors={'reason': reason},
            action_required='step_up',
            created_at=datetime.utcnow()
        )
        self.db.add(event)
        self.db.commit()

    def complete_step_up(self, session: UserSession):
        """
        Mark step-up authentication as completed for session.

        The session becomes ACTIVE again and its risk level is reset to LOW;
        the stored risk score is left unchanged.
        """
        session.step_up_completed = True
        session.status = SessionStatus.ACTIVE.value
        session.current_risk_level = RiskLevel.LOW.value
        self.db.commit()

    def cleanup_expired_sessions(self):
        """
        Clean up expired sessions. Should be run periodically.

        Returns:
            The number of ACTIVE sessions that were past their expiry time
            and have been switched to EXPIRED.
        """
        expired = self.db.query(UserSession).filter(
            UserSession.status == SessionStatus.ACTIVE.value,
            UserSession.expires_at < datetime.utcnow()
        ).all()

        for session in expired:
            session.status = SessionStatus.EXPIRED.value
        self.db.commit()
        return len(expired)

    def get_session_statistics(self, user: Optional[User] = None) -> Dict[str, Any]:
        """
        Get session statistics.

        Args:
            user: Restrict the counts to this user; None counts all sessions.

        Returns:
            A dict with the integer counts 'total', 'active', 'suspicious',
            'expired' and 'revoked'.
        """
        base = self.db.query(UserSession)
        if user:
            base = base.filter(UserSession.user_id == user.id)

        def count_with(status) -> int:
            # Query.filter returns a new query, so `base` is never narrowed.
            return base.filter(UserSession.status == status.value).count()

        return {
            'total': base.count(),
            'active': count_with(SessionStatus.ACTIVE),
            'suspicious': count_with(SessionStatus.SUSPICIOUS),
            'expired': count_with(SessionStatus.EXPIRED),
            'revoked': count_with(SessionStatus.REVOKED)
        }
class AnomalyDetector:
    """
    Detects suspicious patterns and anomalies.

    Detection results are stored as AnomalyPattern rows; at most one active
    row per (pattern type, IP address) is created by the detectors here,
    and re-detection refreshes that row.
    """

    def __init__(self, db: Session):
        """Store the database session used for all queries and commits."""
        self.db = db

    @staticmethod
    def _confidence(observed: int, threshold: int) -> float:
        """
        Scale an observed count to a confidence value capped at 1.0.

        The value is observed / (2 * threshold), so the cap is reached at
        twice the threshold. A non-positive threshold yields 1.0 instead of
        dividing by zero.
        """
        if threshold <= 0:
            return 1.0
        return min(1.0, observed / (threshold * 2))

    @staticmethod
    def _merged_pattern_data(
        current: Optional[Dict[str, Any]],
        updates: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return a new dict holding `current` (may be None) overlaid with `updates`."""
        merged = dict(current or {})
        merged.update(updates)
        return merged

    def _record_pattern(
        self,
        pattern_type: str,
        ip_address: str,
        confidence: float,
        pattern_data: Dict[str, Any],
        refresh_data: Optional[Dict[str, Any]] = None
    ) -> AnomalyPattern:
        """
        Refresh the active pattern for this type and IP, or create one.

        Args:
            pattern_type: Value stored in AnomalyPattern.pattern_type.
            ip_address: Offending IP address.
            confidence: Confidence stored when a new pattern is created; an
                existing pattern keeps the confidence it already has.
            pattern_data: Data stored on a newly created pattern.
            refresh_data: Data merged into an existing pattern; defaults to
                `pattern_data`.

        Returns:
            The refreshed or newly created, committed AnomalyPattern.
        """
        existing = self.db.query(AnomalyPattern).filter(
            AnomalyPattern.ip_address == ip_address,
            AnomalyPattern.pattern_type == pattern_type,
            AnomalyPattern.is_active == True  # noqa: E712 - SQL expression
        ).first()

        now = datetime.utcnow()
        if existing:
            existing.last_detected = now
            # Assign a fresh dict instead of mutating in place: in-place
            # changes to a JSON column are only persisted when the column
            # is mutation-tracked, and the stored value may be NULL.
            existing.pattern_data = self._merged_pattern_data(
                existing.pattern_data,
                pattern_data if refresh_data is None else refresh_data
            )
            self.db.commit()
            return existing

        pattern = AnomalyPattern(
            pattern_type=pattern_type,
            ip_address=ip_address,
            severity=RiskLevel.CRITICAL.value,
            confidence=confidence,
            pattern_data=pattern_data,
            is_active=True,
            first_detected=now,
            last_detected=now
        )
        self.db.add(pattern)
        self.db.commit()
        return pattern

    def detect_brute_force(
        self,
        ip_address: str,
        window_minutes: int = 15,
        threshold: int = 10
    ) -> Optional[AnomalyPattern]:
        """
        Detect brute force attack pattern.

        Args:
            ip_address: IP address to inspect.
            window_minutes: Size of the look-back window.
            threshold: Minimum number of failed login attempts from the IP
                inside the window that counts as brute force.

        Returns:
            The created or refreshed AnomalyPattern, or None when the
            threshold is not reached.
        """
        window_start = datetime.utcnow() - timedelta(minutes=window_minutes)
        failed_count = self.db.query(LoginAttempt).filter(
            LoginAttempt.ip_address == ip_address,
            LoginAttempt.success == False,  # noqa: E712 - SQL expression
            LoginAttempt.attempted_at >= window_start
        ).count()

        if failed_count < threshold:
            return None

        pattern_data = {
            'failed_attempts': failed_count,
            'window_minutes': window_minutes
        }
        return self._record_pattern(
            'brute_force',
            ip_address,
            self._confidence(failed_count, threshold),
            pattern_data,
            # Earlier refreshes wrote the count under 'count' only, leaving
            # 'failed_attempts' stale; 'count' is kept for existing readers.
            refresh_data={**pattern_data, 'count': failed_count}
        )

    def detect_credential_stuffing(
        self,
        ip_address: str,
        window_minutes: int = 15,
        unique_users_threshold: int = 5
    ) -> Optional[AnomalyPattern]:
        """
        Detect credential stuffing attack pattern.

        Args:
            ip_address: IP address to inspect.
            window_minutes: Size of the look-back window.
            unique_users_threshold: Minimum number of distinct emails tried
                from the IP inside the window (successful or not) that
                counts as credential stuffing.

        Returns:
            The created or refreshed AnomalyPattern, or None when the
            threshold is not reached.
        """
        from sqlalchemy import distinct

        # Count unique users attempted from this IP
        window_start = datetime.utcnow() - timedelta(minutes=window_minutes)
        unique_users = self.db.query(
            distinct(LoginAttempt.email)
        ).filter(
            LoginAttempt.ip_address == ip_address,
            LoginAttempt.attempted_at >= window_start
        ).count()

        if unique_users < unique_users_threshold:
            return None

        return self._record_pattern(
            'credential_stuffing',
            ip_address,
            self._confidence(unique_users, unique_users_threshold),
            {
                'unique_users': unique_users,
                'window_minutes': window_minutes
            }
        )

    def get_active_anomalies(
        self,
        user_id: Optional[int] = None,
        ip_address: Optional[str] = None
    ) -> List[AnomalyPattern]:
        """
        Get active anomaly patterns, most recently detected first.

        Args:
            user_id: Only return patterns for this user; None disables the
                filter.
            ip_address: Only return patterns for this IP; None or an empty
                string disables the filter.
        """
        query = self.db.query(AnomalyPattern).filter(
            AnomalyPattern.is_active == True  # noqa: E712 - SQL expression
        )
        # Explicit None check so that a user id of 0 is still honoured.
        if user_id is not None:
            query = query.filter(AnomalyPattern.user_id == user_id)
        if ip_address:
            query = query.filter(AnomalyPattern.ip_address == ip_address)
        return query.order_by(AnomalyPattern.last_detected.desc()).all()

    def resolve_anomaly(self, anomaly_id: int, false_positive: bool = False):
        """
        Resolve an anomaly pattern.

        Deactivates the pattern, records whether it was a false positive and
        stamps the resolution time. Does nothing when the id is unknown.
        """
        anomaly = self.db.query(AnomalyPattern).filter(
            AnomalyPattern.id == anomaly_id
        ).first()
        if anomaly:
            anomaly.is_active = False
            anomaly.false_positive = false_positive
            anomaly.resolved_at = datetime.utcnow()
            self.db.commit()