Spaces:
Sleeping
Sleeping
| from datetime import datetime, timedelta | |
| import logging | |
| from app.core.cache_client import get_redis | |
| from fastapi import HTTPException | |
logger = logging.getLogger(__name__)


class SocialSecurityModel:
    """Redis-backed security helpers for social (OAuth) login.

    Every helper is a ``staticmethod``, so it can be called on the class
    (``SocialSecurityModel.check_oauth_rate_limit(...)``) or on an instance.

    Redis keys used (per client IP and provider unless noted):

    * ``oauth_rate:{ip}:{provider}``      -- attempt counter for rate limiting
    * ``oauth_failed:{ip}:{provider}``    -- failed-verification counter
    * ``oauth_ip_locked:{ip}:{provider}`` -- lock flag
    * ``oauth_log:{YYYY-MM-DD}``          -- daily (UTC) JSON attempt log

    The Redis-backed helpers fail open. On any Redis error they log the error
    and behave as if no limit or lock applies, so that a Redis problem does
    not block legitimate users.
    """

    # Rate limiting: max OAuth attempts per IP and provider within the window.
    OAUTH_RATE_LIMIT_MAX = 5
    OAUTH_RATE_LIMIT_WINDOW = 3600  # 1 hour, in seconds

    # Failed-attempt tracking: reaching the max within the window locks the IP.
    OAUTH_FAILED_ATTEMPTS_MAX = 3
    OAUTH_FAILED_ATTEMPTS_WINDOW = 1800  # 30 minutes, in seconds
    OAUTH_IP_LOCK_DURATION = 3600  # 1 hour lock, in seconds

    @staticmethod
    async def _incr_with_expiry(redis, key: str, window: int) -> int:
        """Increment the counter ``key`` and ensure it expires after ``window`` seconds.

        The TTL is normally set on the first increment. If that EXPIRE never
        happened (e.g. the connection dropped right after INCR, and the error
        was swallowed), the key would live forever and block the IP
        permanently. So a key found with no TTL (``ttl == -1``) is repaired too.

        Returns:
            The counter value after incrementing.
        """
        count = await redis.incr(key)
        if count == 1 or await redis.ttl(key) == -1:
            await redis.expire(key, window)
        return count

    @staticmethod
    async def check_oauth_rate_limit(client_ip: str, provider: str) -> bool:
        """Return whether ``client_ip`` may make another OAuth attempt with ``provider``.

        Only reads the counter; use ``increment_oauth_rate_limit`` to record
        an attempt.

        Returns:
            ``False`` once the counter has reached ``OAUTH_RATE_LIMIT_MAX``.
            Otherwise ``True``, which includes an empty ``client_ip`` and any
            Redis error (fail open).
        """
        if not client_ip:
            return True  # Allow if no IP provided
        try:
            redis = await get_redis()
            rate_key = f"oauth_rate:{client_ip}:{provider}"
            current_count = await redis.get(rate_key)
            if current_count and int(current_count) >= SocialSecurityModel.OAUTH_RATE_LIMIT_MAX:
                logger.warning(f"OAuth rate limit exceeded for IP {client_ip} and provider {provider}")
                return False
            return True
        except Exception as e:
            logger.error(f"Error checking OAuth rate limit: {str(e)}", exc_info=True)
            return True  # Allow on error to avoid blocking legitimate users

    @staticmethod
    async def increment_oauth_rate_limit(client_ip: str, provider: str):
        """Record one OAuth attempt for ``client_ip``/``provider`` in the current window.

        The window (``OAUTH_RATE_LIMIT_WINDOW``) starts at the first attempt.
        Does nothing for an empty ``client_ip``. Redis errors are logged and
        swallowed.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            rate_key = f"oauth_rate:{client_ip}:{provider}"
            count = await SocialSecurityModel._incr_with_expiry(
                redis, rate_key, SocialSecurityModel.OAUTH_RATE_LIMIT_WINDOW
            )
            logger.info(f"OAuth rate limit count for {client_ip}:{provider} = {count}")
        except Exception as e:
            logger.error(f"Error incrementing OAuth rate limit: {str(e)}", exc_info=True)

    @staticmethod
    async def track_oauth_failed_attempt(client_ip: str, provider: str):
        """Count a failed OAuth verification and lock the IP once the threshold is hit.

        Failures are counted per ``OAUTH_FAILED_ATTEMPTS_WINDOW``. Every failure
        at or above ``OAUTH_FAILED_ATTEMPTS_MAX`` (re)applies the lock, which
        refreshes its duration. Does nothing for an empty ``client_ip``. Redis
        errors are logged and swallowed.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            failed_key = f"oauth_failed:{client_ip}:{provider}"
            attempts = await SocialSecurityModel._incr_with_expiry(
                redis, failed_key, SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_WINDOW
            )
            if attempts >= SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_MAX:
                await SocialSecurityModel.lock_oauth_ip(client_ip, provider)
                logger.warning(f"IP {client_ip} locked for provider {provider} after {attempts} failed attempts")
            logger.info(f"OAuth failed attempts for {client_ip}:{provider} = {attempts}")
        except Exception as e:
            logger.error(f"Error tracking OAuth failed attempt: {str(e)}", exc_info=True)

    @staticmethod
    async def lock_oauth_ip(client_ip: str, provider: str):
        """Lock ``client_ip`` for ``provider`` for ``OAUTH_IP_LOCK_DURATION`` seconds.

        Calling it again while locked resets the lock's TTL. Does nothing for
        an empty ``client_ip``: ``is_oauth_ip_locked`` never consults such a
        lock. Redis errors are logged and swallowed.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
            await redis.setex(lock_key, SocialSecurityModel.OAUTH_IP_LOCK_DURATION, "locked")
            logger.info(f"IP {client_ip} locked for OAuth provider {provider}")
        except Exception as e:
            logger.error(f"Error locking OAuth IP: {str(e)}", exc_info=True)

    @staticmethod
    async def is_oauth_ip_locked(client_ip: str, provider: str) -> bool:
        """Return ``True`` if ``client_ip`` currently holds a lock for ``provider``.

        Returns ``False`` for an empty ``client_ip`` and on Redis errors
        (fail open).
        """
        if not client_ip:
            return False
        try:
            redis = await get_redis()
            lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
            locked = await redis.get(lock_key)
            return locked is not None
        except Exception as e:
            logger.error(f"Error checking OAuth IP lock: {str(e)}", exc_info=True)
            return False

    @staticmethod
    async def clear_oauth_failed_attempts(client_ip: str, provider: str):
        """Reset the failed-attempt counter for ``client_ip``/``provider``.

        Intended for use after a successful verification. It does not remove
        an existing lock. Does nothing for an empty ``client_ip``. Redis errors
        are logged and swallowed.
        """
        if not client_ip:
            return
        try:
            redis = await get_redis()
            failed_key = f"oauth_failed:{client_ip}:{provider}"
            await redis.delete(failed_key)
            logger.info(f"Cleared OAuth failed attempts for {client_ip}:{provider}")
        except Exception as e:
            logger.error(f"Error clearing OAuth failed attempts: {str(e)}", exc_info=True)

    @staticmethod
    async def validate_oauth_token_format(token: str, provider: str) -> bool:
        """Cheap structural sanity check of an OAuth token for ``provider``.

        This does not verify the token. It only checks length and, for
        JWT-style ID tokens, that there are exactly two ``.`` separators:

        * ``google``   -- accepts either of:
          - an ID token: a JWT longer than 100 chars;
          - an access token: starts with ``ya29.``, or is 20-4096 chars with
            fewer than two dots.
          Surrounding whitespace and a case-insensitive ``Bearer `` prefix
          are ignored.
        * ``apple``    -- an ID token: a JWT longer than 100 chars.
        * ``facebook`` -- a token of 21-499 chars.
        * any other provider -- accepted.

        When ``settings.OAUTH_TEST_MODE`` is truthy, any truthy token passes.
        """
        # Local test mode accepts any non-empty token. Settings are loaded
        # lazily and best-effort: if they can't be imported, fall through to
        # the normal checks.
        try:
            from app.core.config import settings

            if getattr(settings, "OAUTH_TEST_MODE", False):
                return bool(token)
        except Exception:
            pass

        if not token or not isinstance(token, str):
            return False

        if provider == "google":
            t = token.strip()
            if t.lower().startswith("bearer "):
                t = t[len("bearer "):]
            if len(t) > 100 and t.count(".") == 2:  # ID token (JWT)
                return True
            # Access tokens: commonly ``ya29.``-prefixed and shorter than JWTs.
            return t.startswith("ya29.") or (20 <= len(t) <= 4096 and t.count(".") < 2)
        if provider == "apple":
            # Apple ID tokens are JWTs.
            return len(token) > 100 and token.count(".") == 2
        if provider == "facebook":
            return 20 < len(token) < 500
        return True  # Allow unknown providers

    @staticmethod
    async def log_oauth_attempt(client_ip: str, provider: str, success: bool, customer_id: str = None):
        """Append a JSON record of an OAuth attempt to today's (UTC) Redis log list.

        Each day's list keeps only the newest 1000 entries. Each write resets
        the list's expiry to 30 days. ``customer_id`` may be ``None`` (e.g. for
        failed attempts). Redis errors are logged and swallowed.
        """
        import json

        try:
            redis = await get_redis()
            # Read the clock once, so the entry's timestamp always falls on
            # the same day as the list it is stored in. Two separate reads
            # could straddle midnight.
            now = datetime.utcnow()
            log_key = f"oauth_log:{now.strftime('%Y-%m-%d')}"
            log_entry = {
                "timestamp": now.isoformat(),
                "ip": client_ip,
                "provider": provider,
                "success": success,
                "customer_id": customer_id,
            }
            # LPUSH puts the newest entry at the head; LTRIM 0..999 keeps the 1000 newest.
            await redis.lpush(log_key, json.dumps(log_entry))
            await redis.ltrim(log_key, 0, 999)
            await redis.expire(log_key, 30 * 24 * 3600)  # 30 days
            logger.info(f"OAuth attempt logged: {provider} from {client_ip} - {'success' if success else 'failed'}")
        except Exception as e:
            logger.error(f"Error logging OAuth attempt: {str(e)}", exc_info=True)