Spaces:
Sleeping
Sleeping
| """ | |
| Security Middleware | |
| Implements rate limiting, input sanitization, CSRF protection, and security headers | |
| SECURITY REQUIREMENTS FOR NEW FEATURES: | |
| ======================================== | |
| 1. INPUT VALIDATION: | |
| - Always use sanitize_input() for text fields (titles, descriptions, comments) | |
| - Use sanitize_dict() for JSON payloads | |
| - Validate file uploads: check file type, size, and scan content | |
| - Validate ObjectIds before database queries with validate_object_id() | |
| 2. AUTHENTICATION & AUTHORIZATION: | |
| - Use Depends(get_current_user) for user-only endpoints | |
| - Use Depends(get_current_admin) for admin-only endpoints | |
| - Never expose user data without authentication | |
| - Check ownership before allowing updates/deletes | |
| 3. RATE LIMITING: | |
| - Apply @limiter.limit() to all write endpoints | |
| - Use stricter limits for sensitive operations (login, registration, money transfers) | |
| - Example: @limiter.limit("5/minute") for login | |
| 4. DATA MASKING: | |
| - Use mask_email(), mask_phone(), mask_sensitive_data() for PII | |
| - Default to masked view, require explicit permission for full data | |
| - Log access to unmasked sensitive data | |
| 5. ERROR HANDLING: | |
| - Never expose stack traces or internal details to users | |
| - Use HTTPException with sanitized messages | |
| - Log full errors server-side with request ID | |
| - Return consistent error format | |
| 6. LOGGING: | |
| - Log all security events (failed logins, access denials, suspicious activity) | |
| - Include request ID in all logs for tracing | |
| - Never log passwords, tokens, or sensitive PII | |
| 7. NEW FEATURE CHECKLIST: | |
| [ ] Input sanitization implemented | |
| [ ] Authentication/authorization configured | |
| [ ] Rate limiting applied | |
| [ ] PII data masked | |
| [ ] Error handling prevents information leakage | |
| [ ] Security logging added | |
| [ ] Unit tests for security scenarios written | |
| [ ] Penetration testing performed | |
| EXAMPLES: | |
| --------- | |
| Chat/Comments Feature: | |
| - Sanitize message content: sanitize_input(message.content) | |
| - Rate limit: @limiter.limit("10/minute") for sending messages | |
| - Authenticate: current_user = Depends(get_current_user) | |
| - Validate: max message length, blocked words list | |
| - Mask: user email/phone in chat metadata | |
| File Upload Feature: | |
| - Validate file type: allowed_types = ['pdf', 'jpg', 'png'] | |
| - Validate file size: max_size = 10 * 1024 * 1024 # 10MB | |
| - Scan content: virus scan, malicious code detection | |
| - Sanitize filename: remove path traversal characters | |
| - Store securely: use IPFS or encrypted storage | |
| """ | |
| from fastapi import Request, HTTPException, status | |
| from fastapi.responses import JSONResponse | |
| from slowapi import Limiter, _rate_limit_exceeded_handler | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.datastructures import Headers | |
| import bleach | |
| import re | |
| from typing import Dict, Any | |
| import secrets | |
| from datetime import datetime, timedelta | |
| import uuid | |
| import logging | |
# Initialize rate limiter (slowapi Limiter using get_remote_address as its key function)
limiter = Limiter(key_func=get_remote_address)
# Configure module-level logger
logger = logging.getLogger(__name__)
# Failed login attempts tracking (in-memory - replace with Redis in production)
# Keys are "<ip>:<endpoint>" strings; each value is a dict holding
# 'count', 'first_attempt' and 'locked_until'.
failed_login_attempts: Dict[str, Dict[str, Any]] = {}
# CSRF token storage (in-memory - replace with Redis in production)
# Maps each issued token to the naive UTC datetime at which it expires.
csrf_tokens: Dict[str, datetime] = {}
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses.

    Every response passing through this middleware gets a fixed set of
    hardening headers plus a Content-Security-Policy. Values already set by
    the downstream handler under the same header names are overwritten.

    Framing policy: the app must remain embeddable in Hugging Face Spaces,
    so embedding is controlled solely by the CSP ``frame-ancestors``
    directive. ``X-Frame-Options`` is deliberately NOT sent: its only valid
    values are ``DENY`` and ``SAMEORIGIN``, and the previously used
    ``ALLOWALL`` is not a recognised directive (browsers ignore it).
    """

    # Headers whose value never varies per request.
    _STATIC_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-XSS-Protection": "1; mode=block",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        "Referrer-Policy": "strict-origin-when-cross-origin",
    }

    # Content Security Policy - allow HF Spaces iframe
    _CONTENT_SECURITY_POLICY = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self' https://s.altnet.rippletest.net:51234 https://testnet.xrpl.org https://*.huggingface.co https://*.hf.space; "
        "frame-ancestors 'self' https://huggingface.co https://*.huggingface.co https://*.hf.space;"
    )

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, then stamp the security headers on its response."""
        response = await call_next(request)

        for header_name, header_value in self._STATIC_HEADERS.items():
            response.headers[header_name] = header_value

        # frame-ancestors (inside the CSP) is the single source of truth for
        # who may embed this app; no X-Frame-Options header is emitted.
        response.headers["Content-Security-Policy"] = self._CONTENT_SECURITY_POLICY
        return response
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag each request with a unique ID for tracking and debugging.

    The ID is stored on ``request.state.request_id``, written to the logs, and
    returned to the client in the ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next):
        """Assign a request ID, log the request/response, and mask unhandled errors."""
        request_id = str(uuid.uuid4())
        # Route handlers can read the ID from request.state
        request.state.request_id = request_id

        client_host = request.client.host if request.client else 'unknown'
        logger.info(f"[{request_id}] {request.method} {request.url.path} - Client: {client_host}")

        try:
            response = await call_next(request)
            # Echo the ID back so clients can quote it when reporting problems
            response.headers["X-Request-ID"] = request_id
            logger.info(f"[{request_id}] Response status: {response.status_code}")
            return response
        except Exception as exc:
            # Full details (with traceback) stay in the server log only
            logger.error(f"[{request_id}] Error: {str(exc)}", exc_info=True)
            error_body = {
                "detail": "Internal server error",
                "request_id": request_id
            }
            # Generic 500 response: no internal details are exposed
            return JSONResponse(
                status_code=500,
                content=error_body,
                headers={"X-Request-ID": request_id}
            )
def sanitize_input(text: str) -> str:
    """Strip HTML markup from user-supplied text to prevent XSS.

    Non-string values are returned unchanged. Strings are passed through
    ``bleach.clean`` with no tags or attributes allowed and ``strip=True``,
    then trimmed of leading/trailing whitespace.
    """
    if isinstance(text, str):
        # Empty allow-lists: no HTML tag or attribute is permitted
        without_markup = bleach.clean(text, tags=[], attributes={}, strip=True)
        return without_markup.strip()
    return text
def sanitize_dict(data: dict) -> dict:
    """Recursively sanitize all string values in a dictionary.

    Walks the whole payload: string values are passed through
    ``sanitize_input``; nested dicts and lists (including dicts or lists that
    sit inside lists) are traversed to any depth. Keys and non-string scalar
    values are kept as they are. A new dict is returned; the input is not
    mutated.

    Args:
        data: The payload to clean. Anything that is not a dict is returned
            unchanged.

    Returns:
        A sanitized copy of ``data``.
    """
    if not isinstance(data, dict):
        return data
    return {key: _sanitize_value(value) for key, value in data.items()}


def _sanitize_value(value):
    """Sanitize one value, descending into nested dicts and lists."""
    if isinstance(value, str):
        return sanitize_input(value)
    if isinstance(value, dict):
        return sanitize_dict(value)
    if isinstance(value, list):
        # Recurse per item so strings hidden in list-of-dict / list-of-list
        # structures are cleaned too, not only top-level list strings.
        return [_sanitize_value(item) for item in value]
    return value
def validate_object_id(id_string: str) -> bool:
    """Validate MongoDB ObjectId format.

    Args:
        id_string: Candidate identifier.

    Returns:
        True only when ``id_string`` is a str made of exactly 24 hexadecimal
        characters and nothing else; False for any other value or type.
    """
    if not isinstance(id_string, str):
        return False
    # fullmatch instead of "^...$" with match(): "$" also matches just before
    # a trailing newline, which would let "<24 hex chars>\n" through.
    return re.fullmatch(r'[0-9a-fA-F]{24}', id_string) is not None
def validate_name(name: str) -> tuple[bool, str]:
    """Validate and sanitize a user name.

    The name is sanitized with ``sanitize_input`` and trimmed, then must be
    2-100 characters long, consist only of ASCII letters, spaces, hyphens and
    apostrophes, and must not contain two or more spaces in a row.

    Args:
        name: Raw name as submitted by the user.

    Returns:
        ``(True, sanitized_name)`` on success, otherwise
        ``(False, error_message)``.
    """
    if not name:
        return False, "Name is required"

    # Sanitize first
    name = sanitize_input(name).strip()

    # Check length
    if len(name) < 2:
        return False, "Name must be at least 2 characters"
    if len(name) > 100:
        return False, "Name must not exceed 100 characters"

    # Only letters, literal spaces, hyphens and apostrophes. A literal space
    # (not \s) keeps tabs/newlines out, matching the error message below.
    if re.fullmatch(r"[a-zA-Z '\-]+", name) is None:
        return False, "Name can only contain letters, spaces, hyphens, and apostrophes"

    # Reject runs of spaces; a single space between words is fine.
    if '  ' in name:
        return False, "Name cannot contain multiple consecutive spaces"

    return True, name
def validate_phone(phone: str) -> tuple[bool, str]:
    """Check a required phone number: it must be exactly 10 digits.

    Returns ``(True, sanitized_phone)`` when valid, otherwise
    ``(False, error_message)``.
    """
    if not phone:
        return False, "Phone number is required"

    # Sanitize before inspecting
    phone = sanitize_input(phone).strip()

    # Count digits regardless of formatting to pick the right error message
    digit_count = len(re.sub(r'\D', '', phone))
    if digit_count != 10:
        return False, "Phone number must be exactly 10 digits"

    # Right number of digits, but formatting characters are not accepted
    if re.match(r'^[0-9]{10}$', phone) is None:
        return False, "Phone number must contain only 10 digits (no spaces or special characters)"

    return True, phone
def validate_date(date_str: str) -> tuple[bool, str]:
    """Check an optional date given as YYYY-MM-DD.

    Returns ``(True, None)`` when no date is supplied, ``(True, date)`` for a
    well-formed date that is not in the future, otherwise
    ``(False, error_message)``.
    """
    # The field is optional: nothing supplied is acceptable
    if not date_str:
        return True, None

    date_str = sanitize_input(date_str).strip()

    # Shape check before attempting a real parse
    if re.match(r'^\d{4}-\d{2}-\d{2}$', date_str) is None:
        return False, "Date must be in YYYY-MM-DD format"

    # A correct shape can still be an impossible calendar date
    try:
        parsed = datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError:
        return False, "Invalid date"

    if parsed > datetime.now():
        return False, "Date cannot be in the future"
    return True, date_str
def validate_gender(gender: str) -> tuple[bool, str]:
    """Check an optional gender selection against the fixed list of choices.

    Returns ``(True, None)`` when nothing is supplied, ``(True, gender)`` for
    an accepted value, otherwise ``(False, error_message)``.
    """
    # The field is optional
    if not gender:
        return True, None

    gender = sanitize_input(gender).strip()

    # Only these exact (case-sensitive) values are accepted
    valid_genders = ['Male', 'Female', 'Other']
    if gender in valid_genders:
        return True, gender
    return False, f"Gender must be one of: {', '.join(valid_genders)}"
def validate_address(address: str) -> tuple[bool, str]:
    """Check and sanitize an optional postal address.

    Returns ``(True, "")`` when nothing is supplied, ``(True, address)`` for
    an accepted value, otherwise ``(False, error_message)``.
    """
    # The field is optional
    if not address:
        return True, ""

    address = sanitize_input(address).strip()

    if len(address) > 500:
        return False, "Address must not exceed 500 characters"

    # Letters, digits, whitespace and the usual address punctuation only
    allowed_chars = re.compile(r"^[a-zA-Z0-9\s,.\-'#/()\n]+$")
    if allowed_chars.match(address) is None:
        return False, "Address contains invalid characters"

    return True, address
def validate_redirect_url(url: str, allowed_domains: list = None) -> bool:
    """Validate a redirect target to prevent open-redirect attacks.

    Accepted targets:
      * relative paths that start with a single ``/`` (``//host`` is a
        protocol-relative URL and is rejected);
      * absolute ``http``/``https`` URLs whose host is in ``allowed_domains``.

    The host is extracted with ``urllib.parse.urlsplit`` and compared as a
    whole, so look-alike hosts such as ``http://localhost.evil.com`` or
    ``http://localhost@evil.com`` are rejected. Leading and trailing
    whitespace in ``url`` is ignored.

    Args:
        url: Redirect target supplied by the client.
        allowed_domains: Allowed hosts, each either ``"host"`` (any port on
            that host is accepted) or ``"host:port"`` (that exact port only).
            Defaults to localhost / 127.0.0.1 development origins.

    Returns:
        True if redirecting to ``url`` is considered safe, else False.
    """
    # Function-scope import: the block needs nothing added to the file header.
    from urllib.parse import urlsplit

    if not url or not isinstance(url, str):
        return False

    candidate = url.strip()
    if not candidate:
        return False

    # Browsers treat "\" like "/" and silently drop tabs/newlines, so
    # "/\evil.com" or "java\tscript:" would slip past naive prefix checks.
    if '\\' in candidate or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate):
        return False

    # Block dangerous schemes
    dangerous_schemes = ('javascript:', 'data:', 'vbscript:', 'file:', 'about:')
    if candidate.lower().startswith(dangerous_schemes):
        return False

    # Default allowed domains (localhost and local dev)
    if allowed_domains is None:
        allowed_domains = [
            'localhost',
            '127.0.0.1',
            'localhost:5173',
            'localhost:5174',
            'localhost:5175',
            '127.0.0.1:5173',
            '127.0.0.1:5174',
            '127.0.0.1:5175'
        ]

    # Relative URL: a single leading slash only
    if candidate.startswith('/'):
        return not candidate.startswith('//')

    try:
        parts = urlsplit(candidate)
        port = parts.port  # raises ValueError for a malformed/out-of-range port
    except ValueError:
        return False

    if parts.scheme not in ('http', 'https'):
        return False

    # Credentials in the authority ("http://trusted@evil.com") are a classic
    # way to disguise the real host; never accept them.
    if parts.username is not None or parts.password is not None:
        return False

    host = (parts.hostname or '').lower()
    if not host:
        return False

    host_with_port = f"{host}:{port}" if port is not None else host
    accepted_forms = {host, host_with_port, parts.netloc.lower()}

    # Exact comparison against the parsed host - never a string prefix match
    return any(str(domain).lower() in accepted_forms for domain in allowed_domains)
def generate_csrf_token() -> str:
    """Generate a CSRF token and register it with a one-hour lifetime.

    Expired tokens are purged from the in-memory ``csrf_tokens`` store on each
    call. Without this, a token that is issued but never presented again after
    it expires would stay in the store forever, so memory would grow with
    every token ever issued.

    Returns:
        A new URL-safe random token.
    """
    # Naive UTC, matching what validate_csrf_token compares against
    now = datetime.utcnow()

    # Iterate over a snapshot so entries can be removed while scanning
    for stale_token, expires_at in list(csrf_tokens.items()):
        if expires_at < now:
            csrf_tokens.pop(stale_token, None)

    token = secrets.token_urlsafe(32)
    csrf_tokens[token] = now + timedelta(hours=1)
    return token
def validate_csrf_token(token: str) -> bool:
    """Report whether ``token`` is a known, unexpired CSRF token.

    An expired token is removed from the store as a side effect.
    """
    if not token:
        return False
    if token not in csrf_tokens:
        return False

    expires_at = csrf_tokens[token]
    if expires_at >= datetime.utcnow():
        return True

    # Expired: forget the token so it cannot be checked again
    del csrf_tokens[token]
    return False
def check_rate_limit(ip: str, endpoint: str, max_attempts: int = 5, window_minutes: int = 15) -> bool:
    """Decide whether an IP may attempt ``endpoint`` again after failed logins.

    Returns True if the attempt is allowed, False if the IP is blocked.
    A tracking record is created for the IP/endpoint pair if none exists.
    """
    key = f"{ip}:{endpoint}"
    now = datetime.utcnow()

    # First sighting of this key starts a fresh tracking record
    record = failed_login_attempts.setdefault(key, {
        'count': 0,
        'first_attempt': now,
        'locked_until': None
    })

    # Still inside an active lock
    locked_until = record['locked_until']
    if locked_until and locked_until > now:
        return False

    # Window elapsed: start counting from scratch
    if now - record['first_attempt'] > timedelta(minutes=window_minutes):
        failed_login_attempts[key] = {
            'count': 0,
            'first_attempt': now,
            'locked_until': None
        }
        return True

    # Too many failures inside the window: lock for window_minutes
    if record['count'] >= max_attempts:
        record['locked_until'] = now + timedelta(minutes=window_minutes)
        return False

    return True
def record_failed_attempt(ip: str, endpoint: str):
    """Count one more failed login for the given IP/endpoint pair."""
    key = f"{ip}:{endpoint}"
    now = datetime.utcnow()

    if key in failed_login_attempts:
        # Already tracked: bump the counter only
        failed_login_attempts[key]['count'] += 1
        return

    # First failure for this key
    failed_login_attempts[key] = {
        'count': 1,
        'first_attempt': now,
        'locked_until': None
    }
def reset_failed_attempts(ip: str, endpoint: str):
    """Forget the failed-attempt record for an IP/endpoint pair (e.g. after a successful login)."""
    # pop with a default is a no-op when the key was never tracked
    failed_login_attempts.pop(f"{ip}:{endpoint}", None)
def mask_sensitive_data(data: str, mask_char: str = "*", visible_chars: int = 4) -> str:
    """Mask sensitive data, showing only the last ``visible_chars`` characters.

    The mask always has a fixed length of 12 characters, so the output does
    not reveal how long the original value was.

    Args:
        data: Value to mask.
        mask_char: Character used to build the mask.
        visible_chars: Number of trailing characters left readable. Zero or a
            negative number means nothing is revealed.

    Returns:
        The masked string. Empty/None input, or input no longer than
        ``visible_chars``, is returned unchanged.
    """
    if not data or len(data) <= visible_chars:
        return data

    # Use fixed mask length for consistent display (prevents length-based attacks)
    fixed_mask_length = 12
    mask = mask_char * fixed_mask_length

    # data[-0:] is the WHOLE string and a negative count slices from the
    # front, so both must be handled explicitly to avoid leaking the value.
    if visible_chars <= 0:
        return mask

    return mask + data[-visible_chars:]
def mask_email(email: str) -> str:
    """Hide the middle of an email's local part while keeping the domain.

    Values without an ``@`` (or empty values) are returned unchanged, as are
    addresses whose local part has two characters or fewer.
    """
    if not email or '@' not in email:
        return email

    local, _, domain = email.partition('@')
    if len(local) > 2:
        # Keep the first and last character, star out everything between
        hidden = '*' * (len(local) - 2)
        local = f"{local[0]}{hidden}{local[-1]}"
    return f"{local}@{domain}"
def mask_phone(phone: str) -> str:
    """Mask a phone number so only its last 4 digits stay readable.

    Formatting characters are dropped from the masked result. Empty values and
    numbers with four digits or fewer are returned unchanged.
    """
    if not phone:
        return phone

    # Work on the digits alone
    digits = re.sub(r'\D', '', phone)
    hidden_count = len(digits) - 4
    if hidden_count <= 0:
        return phone
    return '*' * hidden_count + digits[-4:]