import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException

from app.core.cache_client import get_redis

logger = logging.getLogger(__name__)


class SocialSecurityModel:
    """Model for handling social login security features.

    All state lives in Redis under keys namespaced by client IP and OAuth
    provider. Every method is best-effort: Redis errors are logged and
    swallowed, and the read-side checks fail open so an outage does not
    block legitimate logins.
    """

    # Rate limiting constants
    OAUTH_RATE_LIMIT_MAX = 5  # Max OAuth attempts per IP and provider per window
    OAUTH_RATE_LIMIT_WINDOW = 3600  # 1 hour in seconds

    # Failed attempt tracking
    OAUTH_FAILED_ATTEMPTS_MAX = 3  # Max failed OAuth attempts per IP and provider
    OAUTH_FAILED_ATTEMPTS_WINDOW = 1800  # 30 minutes
    OAUTH_IP_LOCK_DURATION = 3600  # 1 hour lock for IP (per provider)

    @staticmethod
    async def _increment_windowed_counter(redis, key: str, window: int) -> int:
        """Increment ``key`` and make sure it expires ``window`` seconds after creation.

        The TTL is normally attached when the counter is created (count == 1).
        If that EXPIRE never ran (Redis error or process death between INCR and
        EXPIRE), the key would otherwise live forever and keep the IP limited
        permanently; a key found without any TTL (Redis TTL reply -1) is
        therefore repaired here.

        Returns:
            The counter value after the increment.
        """
        count = await redis.incr(key)
        if count == 1 or await redis.ttl(key) == -1:
            await redis.expire(key, window)
        return count

    @staticmethod
    async def check_oauth_rate_limit(client_ip: str, provider: str) -> bool:
        """Return whether another OAuth attempt is allowed for this IP and provider.

        This only reads the counter; attempts are recorded separately by
        ``increment_oauth_rate_limit``, so check and increment are not atomic.

        Returns:
            ``False`` once the counter has reached ``OAUTH_RATE_LIMIT_MAX``.
            ``True`` otherwise, including when ``client_ip`` is empty or when
            Redis raises (fail-open).
        """
        if not client_ip:
            return True  # Allow if no IP provided

        try:
            redis = await get_redis()
            rate_key = f"oauth_rate:{client_ip}:{provider}"
            current_count = await redis.get(rate_key)
            if current_count and int(current_count) >= SocialSecurityModel.OAUTH_RATE_LIMIT_MAX:
                logger.warning(
                    "OAuth rate limit exceeded for IP %s and provider %s", client_ip, provider
                )
                return False
            return True
        except Exception as e:
            logger.exception("Error checking OAuth rate limit: %s", e)
            return True  # Allow on error to avoid blocking legitimate users

    @staticmethod
    async def increment_oauth_rate_limit(client_ip: str, provider: str):
        """Record one OAuth attempt for this IP and provider.

        The counter expires ``OAUTH_RATE_LIMIT_WINDOW`` seconds after the first
        attempt in the window. No-op when ``client_ip`` is empty; errors are
        logged and swallowed.
        """
        if not client_ip:
            return

        try:
            redis = await get_redis()
            rate_key = f"oauth_rate:{client_ip}:{provider}"
            count = await SocialSecurityModel._increment_windowed_counter(
                redis, rate_key, SocialSecurityModel.OAUTH_RATE_LIMIT_WINDOW
            )
            logger.info("OAuth rate limit count for %s:%s = %s", client_ip, provider, count)
        except Exception as e:
            logger.exception("Error incrementing OAuth rate limit: %s", e)

    @staticmethod
    async def track_oauth_failed_attempt(client_ip: str, provider: str):
        """Record a failed OAuth verification and lock the IP once the limit is hit.

        The failure counter expires ``OAUTH_FAILED_ATTEMPTS_WINDOW`` seconds
        after the first failure. When it reaches ``OAUTH_FAILED_ATTEMPTS_MAX``,
        the IP is locked for this provider via ``lock_oauth_ip``. Each further
        failure while the counter is alive locks again, which restarts the lock
        duration (``lock_oauth_ip`` uses SETEX). No-op when ``client_ip`` is
        empty; errors are logged and swallowed.
        """
        if not client_ip:
            return

        try:
            redis = await get_redis()
            failed_key = f"oauth_failed:{client_ip}:{provider}"
            attempts = await SocialSecurityModel._increment_windowed_counter(
                redis, failed_key, SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_WINDOW
            )

            # Lock IP if too many failed attempts
            if attempts >= SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_MAX:
                await SocialSecurityModel.lock_oauth_ip(client_ip, provider)
                logger.warning(
                    "IP %s locked for provider %s after %s failed attempts",
                    client_ip,
                    provider,
                    attempts,
                )

            logger.info("OAuth failed attempts for %s:%s = %s", client_ip, provider, attempts)
        except Exception as e:
            logger.exception("Error tracking OAuth failed attempt: %s", e)

    @staticmethod
    async def lock_oauth_ip(client_ip: str, provider: str):
        """Lock an IP for OAuth attempts on one provider for ``OAUTH_IP_LOCK_DURATION`` seconds.

        Re-locking an already locked IP resets the remaining lock time. No-op
        when ``client_ip`` is empty (consistent with the other helpers, and
        ``is_oauth_ip_locked`` never reports an empty IP as locked anyway).
        Errors are logged and swallowed.
        """
        if not client_ip:
            return

        try:
            redis = await get_redis()
            lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
            await redis.setex(lock_key, SocialSecurityModel.OAUTH_IP_LOCK_DURATION, "locked")
            logger.info("IP %s locked for OAuth provider %s", client_ip, provider)
        except Exception as e:
            logger.exception("Error locking OAuth IP: %s", e)

    @staticmethod
    async def is_oauth_ip_locked(client_ip: str, provider: str) -> bool:
        """Return whether the IP is currently locked for this provider.

        Returns ``False`` when ``client_ip`` is empty or when Redis raises
        (fail-open).
        """
        if not client_ip:
            return False

        try:
            redis = await get_redis()
            lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
            locked = await redis.get(lock_key)
            return locked is not None
        except Exception as e:
            logger.exception("Error checking OAuth IP lock: %s", e)
            return False

    @staticmethod
    async def clear_oauth_failed_attempts(client_ip: str, provider: str):
        """Reset the failed-attempt counter after a successful verification.

        An existing IP lock is not removed. No-op when ``client_ip`` is empty;
        errors are logged and swallowed.
        """
        if not client_ip:
            return

        try:
            redis = await get_redis()
            failed_key = f"oauth_failed:{client_ip}:{provider}"
            await redis.delete(failed_key)
            logger.info("Cleared OAuth failed attempts for %s:%s", client_ip, provider)
        except Exception as e:
            logger.exception("Error clearing OAuth failed attempts: %s", e)

    @staticmethod
    async def validate_oauth_token_format(token: str, provider: str) -> bool:
        """Cheap structural sanity check of an OAuth token before real verification.

        This is not signature or expiry validation. It only rejects inputs
        that cannot plausibly be a token for ``provider``:

        * ``google``: an optional ``Bearer`` prefix is stripped. The token is
          then accepted if it looks like a JWT (> 100 chars, exactly two
          dots) or like an access token (``ya29.`` prefix, or 20-4096 chars
          with fewer than two dots), capped at 4096 chars for access tokens.
        * ``apple``: JWT shape (> 100 chars, exactly two dots).
        * ``facebook``: 21-499 chars.
        * Any other provider is accepted.

        When ``settings.OAUTH_TEST_MODE`` is truthy, any truthy ``token`` is
        accepted.
        """
        # In local test mode, accept any non-empty string to facilitate testing.
        # Best-effort: if settings cannot be loaded, fall through to real checks.
        try:
            from app.core.config import settings

            if getattr(settings, "OAUTH_TEST_MODE", False):
                return bool(token)
        except Exception:
            pass

        if not token or not isinstance(token, str):
            return False

        # Basic length and format checks
        if provider == "google":
            # Normalize optional Bearer prefix (and any extra whitespace after it)
            t = token.strip()
            if t.lower().startswith("bearer "):
                t = t[7:].strip()
            # Accept Google ID tokens (JWT)
            if len(t) > 100 and t.count(".") == 2:
                return True
            # Accept Google OAuth access tokens (commonly start with 'ya29.'),
            # bounded like the generic access-token path so oversized input is rejected
            if len(t) <= 4096 and (t.startswith("ya29.") or (len(t) >= 20 and t.count(".") < 2)):
                return True
            return False
        elif provider == "apple":
            # Apple ID tokens are also JWT format
            return len(token) > 100 and token.count(".") == 2
        elif provider == "facebook":
            # Facebook access tokens are typically shorter
            return 20 < len(token) < 500

        return True  # Allow unknown providers

    @staticmethod
    async def log_oauth_attempt(
        client_ip: str, provider: str, success: bool, customer_id: Optional[str] = None
    ):
        """Append an OAuth attempt record to a per-day Redis list for security monitoring.

        Entries are JSON objects pushed to the head of ``oauth_log:<YYYY-MM-DD>``
        (UTC date). The list is trimmed to the newest 1000 entries and expires
        30 days after the last write. Errors are logged and swallowed.
        """
        try:
            redis = await get_redis()
            # Read the clock once so the day bucket and the entry timestamp always
            # agree, even when the call straddles midnight UTC. tzinfo is dropped to
            # keep the existing naive-UTC ISO format (no "+00:00" suffix).
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            log_key = f"oauth_log:{now.strftime('%Y-%m-%d')}"

            log_entry = {
                "timestamp": now.isoformat(),
                "ip": client_ip,
                "provider": provider,
                "success": success,
                "customer_id": customer_id,
            }

            # Store as JSON string in Redis list
            await redis.lpush(log_key, json.dumps(log_entry))

            # Keep only last 1000 entries per day
            await redis.ltrim(log_key, 0, 999)

            # Set expiry for 30 days
            await redis.expire(log_key, 30 * 24 * 3600)

            logger.info(
                "OAuth attempt logged: %s from %s - %s",
                provider,
                client_ip,
                "success" if success else "failed",
            )
        except Exception as e:
            logger.exception("Error logging OAuth attempt: %s", e)