bookmyservice-ums / app /models /social_security_model.py
MukeshKapoor25's picture
feat(oauth): add local testing mode to bypass external verification
2fc2b48
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException

from app.core.cache_client import get_redis
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class SocialSecurityModel:
    """Security helpers for social (OAuth) login, backed by Redis.

    Provides per-IP/per-provider rate limiting, failed-attempt tracking with
    temporary IP locks, a lightweight token-format sanity check, and a daily
    audit log of OAuth attempts.

    Every Redis-backed method fails open: on a Redis error it logs the
    exception and returns the permissive result (or nothing) instead of
    raising, so a cache outage does not block legitimate users.
    """

    # Rate limiting: max OAuth attempts per IP and provider within the window.
    OAUTH_RATE_LIMIT_MAX = 5
    OAUTH_RATE_LIMIT_WINDOW = 3600  # 1 hour, in seconds

    # Failed-attempt tracking: this many failures within the window locks the IP.
    OAUTH_FAILED_ATTEMPTS_MAX = 3
    OAUTH_FAILED_ATTEMPTS_WINDOW = 1800  # 30 minutes, in seconds
    OAUTH_IP_LOCK_DURATION = 3600  # 1 hour lock, in seconds

    # Audit log retention: newest entries kept per daily list, and list TTL.
    OAUTH_LOG_MAX_ENTRIES = 1000
    OAUTH_LOG_TTL = 30 * 24 * 3600  # 30 days, in seconds

    # Upper bound on the length of a non-JWT Google access token.
    GOOGLE_ACCESS_TOKEN_MAX_LEN = 4096

    @staticmethod
    async def _incr_with_window(redis, key: str, window: int) -> int:
        """Increment counter ``key`` and make sure it expires after ``window`` seconds.

        The expiry is normally attached on the first increment. If an earlier
        INCR succeeded but its EXPIRE was lost (crash, dropped connection),
        the key would otherwise never expire. Redis TTL returns -1 for a key
        that exists without an expiry, so that case is repaired here.

        Returns:
            The counter value after incrementing.
        """
        count = await redis.incr(key)
        # Short-circuit: the TTL round trip only happens on later increments.
        if count == 1 or await redis.ttl(key) == -1:
            await redis.expire(key, window)
        return count

    @staticmethod
    async def check_oauth_rate_limit(client_ip: str, provider: str) -> bool:
        """Return whether another OAuth attempt is allowed for this IP and provider.

        Reads the counter maintained by ``increment_oauth_rate_limit``.

        Returns:
            False once the counter has reached ``OAUTH_RATE_LIMIT_MAX``.
            True otherwise, including when ``client_ip`` is empty or Redis
            fails (fail open).
        """
        if not client_ip:
            return True  # Nothing to key the limit on; allow.
        try:
            redis = await get_redis()
            rate_key = f"oauth_rate:{client_ip}:{provider}"
            current_count = await redis.get(rate_key)
            if current_count and int(current_count) >= SocialSecurityModel.OAUTH_RATE_LIMIT_MAX:
                logger.warning(
                    "OAuth rate limit exceeded for IP %s and provider %s",
                    client_ip,
                    provider,
                )
                return False
            return True
        except Exception as e:
            logger.exception("Error checking OAuth rate limit: %s", e)
            return True  # Allow on error to avoid blocking legitimate users.

    @staticmethod
    async def increment_oauth_rate_limit(client_ip: str, provider: str):
        """Count one OAuth attempt against the IP/provider rate limit.

        The counter expires ``OAUTH_RATE_LIMIT_WINDOW`` seconds after the
        first attempt in the window. Does nothing when ``client_ip`` is
        empty. Redis errors are logged, not raised.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            rate_key = f"oauth_rate:{client_ip}:{provider}"
            count = await SocialSecurityModel._incr_with_window(
                redis, rate_key, SocialSecurityModel.OAUTH_RATE_LIMIT_WINDOW
            )
            logger.info("OAuth rate limit count for %s:%s = %s", client_ip, provider, count)
        except Exception as e:
            logger.exception("Error incrementing OAuth rate limit: %s", e)

    @staticmethod
    async def track_oauth_failed_attempt(client_ip: str, provider: str):
        """Record a failed OAuth verification and lock the IP past the threshold.

        Failures are counted per IP and provider over
        ``OAUTH_FAILED_ATTEMPTS_WINDOW`` seconds. Once the count reaches
        ``OAUTH_FAILED_ATTEMPTS_MAX``, the IP is locked for that provider.
        Every further failure re-locks it, which restarts the lock duration.
        Does nothing when ``client_ip`` is empty. Redis errors are logged,
        not raised.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            failed_key = f"oauth_failed:{client_ip}:{provider}"
            attempts = await SocialSecurityModel._incr_with_window(
                redis, failed_key, SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_WINDOW
            )
            if attempts >= SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_MAX:
                await SocialSecurityModel.lock_oauth_ip(client_ip, provider)
                logger.warning(
                    "IP %s locked for provider %s after %s failed attempts",
                    client_ip,
                    provider,
                    attempts,
                )
            logger.info("OAuth failed attempts for %s:%s = %s", client_ip, provider, attempts)
        except Exception as e:
            logger.exception("Error tracking OAuth failed attempt: %s", e)

    @staticmethod
    async def lock_oauth_ip(client_ip: str, provider: str):
        """Lock ``client_ip`` out of OAuth for ``provider``.

        The lock lasts ``OAUTH_IP_LOCK_DURATION`` seconds. SETEX overwrites
        any existing lock and its TTL. Does nothing when ``client_ip`` is
        empty. Redis errors are logged, not raised.
        """
        if not client_ip:
            return  # Avoid creating an "oauth_ip_locked::<provider>" key.
        try:
            redis = await get_redis()
            lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
            await redis.setex(lock_key, SocialSecurityModel.OAUTH_IP_LOCK_DURATION, "locked")
            logger.info("IP %s locked for OAuth provider %s", client_ip, provider)
        except Exception as e:
            logger.exception("Error locking OAuth IP: %s", e)

    @staticmethod
    async def is_oauth_ip_locked(client_ip: str, provider: str) -> bool:
        """Return True if ``client_ip`` currently holds a lock for ``provider``.

        Returns False when ``client_ip`` is empty or Redis fails (fail open).
        """
        if not client_ip:
            return False
        try:
            redis = await get_redis()
            lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
            return await redis.get(lock_key) is not None
        except Exception as e:
            logger.exception("Error checking OAuth IP lock: %s", e)
            return False

    @staticmethod
    async def clear_oauth_failed_attempts(client_ip: str, provider: str):
        """Reset the failed-attempt counter after a successful verification.

        Only the counter is deleted. An active IP lock is left in place.
        Does nothing when ``client_ip`` is empty. Redis errors are logged,
        not raised.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            failed_key = f"oauth_failed:{client_ip}:{provider}"
            await redis.delete(failed_key)
            logger.info("Cleared OAuth failed attempts for %s:%s", client_ip, provider)
        except Exception as e:
            logger.exception("Error clearing OAuth failed attempts: %s", e)

    @staticmethod
    async def validate_oauth_token_format(token: str, provider: str) -> bool:
        """Cheap structural sanity check of an OAuth token before real verification.

        This does NOT verify signatures or validity with the provider. It only
        rejects obviously malformed input.

        Args:
            token: Raw token string from the client.
            provider: "google", "apple", "facebook"; other values are accepted
                unconditionally.

        Returns:
            True if the token looks plausible for the provider. When the
            ``OAUTH_TEST_MODE`` setting is truthy, any non-empty string is
            accepted.
        """
        if not token or not isinstance(token, str):
            return False

        # In local test mode, accept any non-empty string to facilitate testing.
        try:
            # Imported lazily, presumably to avoid an import cycle with app config.
            from app.core.config import settings

            if getattr(settings, "OAUTH_TEST_MODE", False):
                return True
        except Exception:
            # Config unavailable: deliberately fall through to normal checks.
            pass

        if provider == "google":
            t = token.strip()
            # Accept an optional, case-insensitive "Bearer " prefix.
            if t.lower().startswith("bearer "):
                t = t[len("bearer "):].strip()
            # Google ID token: a JWT (header.payload.signature).
            if len(t) > 100 and t.count(".") == 2:
                return True
            if len(t) > SocialSecurityModel.GOOGLE_ACCESS_TOKEN_MAX_LEN:
                return False
            # Google OAuth access tokens commonly start with "ya29."; the bare
            # prefix on its own is not a token.
            if t.startswith("ya29."):
                return len(t) > len("ya29.")
            # Otherwise accept an opaque, non-JWT-shaped string of plausible length.
            return len(t) >= 20 and t.count(".") < 2
        if provider == "apple":
            # Apple ID tokens are JWTs.
            return len(token) > 100 and token.count(".") == 2
        if provider == "facebook":
            # Facebook access tokens are typically shorter.
            return 20 < len(token) < 500
        return True  # Unknown providers are not format-checked.

    @staticmethod
    async def log_oauth_attempt(
        client_ip: str, provider: str, success: bool, customer_id: Optional[str] = None
    ):
        """Append an OAuth attempt to the daily audit log in Redis.

        Entries are JSON strings pushed onto the ``oauth_log:<YYYY-MM-DD>``
        list (UTC date), newest first. Only the newest
        ``OAUTH_LOG_MAX_ENTRIES`` entries are kept. The list's TTL is reset to
        ``OAUTH_LOG_TTL`` on every write. Redis and serialization errors are
        logged, not raised.
        """
        try:
            redis = await get_redis()
            # Capture the time once so the entry's timestamp and its daily key agree.
            now = datetime.now(timezone.utc)
            log_key = f"oauth_log:{now.strftime('%Y-%m-%d')}"
            log_entry = {
                # Naive UTC ISO string, same format previously produced by utcnow().
                "timestamp": now.replace(tzinfo=None).isoformat(),
                "ip": client_ip,
                "provider": provider,
                "success": success,
                "customer_id": customer_id,
            }
            await redis.lpush(log_key, json.dumps(log_entry))
            # LPUSH puts the newest entry at index 0, so this keeps the newest N.
            await redis.ltrim(log_key, 0, SocialSecurityModel.OAUTH_LOG_MAX_ENTRIES - 1)
            await redis.expire(log_key, SocialSecurityModel.OAUTH_LOG_TTL)
            logger.info(
                "OAuth attempt logged: %s from %s - %s",
                provider,
                client_ip,
                "success" if success else "failed",
            )
        except Exception as e:
            logger.exception("Error logging OAuth attempt: %s", e)