Spaces:
Sleeping
Sleeping
import logging
import uuid
from typing import Callable, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from ..database import db_manager
class SessionMiddleware(BaseHTTPMiddleware):
    """Middleware that attaches a session ID to each request and, for
    database-backed endpoints, makes sure the session has a hotel context.

    The session ID is read from the ``x-session-id`` request header (a new
    UUID4 is generated when the header is missing or empty), stored on
    ``request.state.session_id`` and echoed back in the ``x-session-id``
    response header, including on the error responses produced here.
    """

    # Path prefixes treated as database-related endpoints.
    _DATABASE_PREFIXES = (
        '/settings/',
        '/customer/api/',
        '/chef/',
        '/admin/',
        '/analytics/',
        '/tables/',
        '/feedback/',
        '/loyalty/',
        '/selection-offers/',
    )

    # Exact paths that never require a hotel context.
    _SKIP_VALIDATION_ENDPOINTS = frozenset({
        '/settings/databases',
        '/settings/hotels',
        '/settings/switch-database',
        '/settings/switch-hotel',
        '/settings/current-database',
        '/settings/current-hotel',
    })

    # Admin and chef routes handle their own database selection.
    _SKIP_VALIDATION_PREFIXES = (
        '/admin/',
        '/chef/',
    )

    def __init__(self, app, require_database: bool = True):
        """Wrap *app*.

        Args:
            app: The ASGI application to wrap.
            require_database: When False, the hotel-context check is
                disabled for every path.
        """
        super().__init__(app)
        self.require_database = require_database

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Assign the session ID, validate hotel context, then forward.

        Returns:
            The downstream response with ``x-session-id`` set, or a JSON
            error response (400 / 401 / 500) when the hotel context cannot
            be established for a database endpoint.
        """
        # CORS preflight requests pass straight through, untouched.
        if request.method == "OPTIONS":
            return await call_next(request)

        # Reuse the caller's session ID, or mint one if absent/empty.
        session_id = request.headers.get('x-session-id') or str(uuid.uuid4())
        request.state.session_id = session_id

        if self._should_validate(request.url.path):
            error_response = self._ensure_hotel_context(request, session_id)
            if error_response is not None:
                return error_response

        response = await call_next(request)
        response.headers["x-session-id"] = session_id
        return response

    def _should_validate(self, path: str) -> bool:
        """Return True when *path* must have a hotel context to proceed."""
        if not self.require_database:
            return False
        if not path.startswith(self._DATABASE_PREFIXES):
            return False
        if path in self._SKIP_VALIDATION_ENDPOINTS:
            return False
        return not path.startswith(self._SKIP_VALIDATION_PREFIXES)

    def _ensure_hotel_context(
        self, request: Request, session_id: str
    ) -> Optional[JSONResponse]:
        """Make sure *session_id* has a hotel context.

        If the session has no current hotel, fall back to the
        ``x-hotel-name`` / ``x-hotel-password`` request headers and try to
        authenticate with them.

        Returns:
            None when the session has (or now has) a hotel context,
            otherwise the JSON error response to send to the client.
        """
        if db_manager.get_current_hotel_id(session_id):
            return None

        stored_hotel_name = request.headers.get('x-hotel-name')
        stored_password = request.headers.get('x-hotel-password')
        if not (stored_hotel_name and stored_password):
            return self._error_response(
                400,
                "No hotel selected. Please select a hotel first.",
                "HOTEL_NOT_SELECTED",
                session_id,
            )

        try:
            hotel_id = db_manager.authenticate_hotel(stored_hotel_name, stored_password)
            if hotel_id:
                db_manager.set_hotel_context(session_id, hotel_id)
        except Exception:
            # Boundary handler: keep the failure details in the server log
            # and do not leak internal exception text to the client.
            logging.getLogger(__name__).exception(
                "Unexpected error while authenticating hotel from headers"
            )
            return self._error_response(
                500,
                "Hotel authentication failed",
                "HOTEL_VERIFICATION_ERROR",
                session_id,
            )

        if not hotel_id:
            return self._error_response(
                401,
                "Invalid hotel credentials",
                "HOTEL_AUTH_FAILED",
                session_id,
            )
        return None

    @staticmethod
    def _error_response(
        status_code: int, detail: str, error_code: str, session_id: str
    ) -> JSONResponse:
        """Build a JSON error response that still carries the session ID."""
        return JSONResponse(
            status_code=status_code,
            content={"detail": detail, "error_code": error_code},
            headers={"x-session-id": session_id},
        )
def get_session_id(request: Request) -> str:
    """Return the session ID stored on ``request.state``.

    When no session ID has been stored yet, a new UUID4 string is
    generated and saved on ``request.state.session_id`` so that repeated
    calls for the same request return the same value.
    """
    session_id = getattr(request.state, 'session_id', None)
    if session_id is None:
        # Generate lazily (only when actually missing) and remember it.
        session_id = str(uuid.uuid4())
        request.state.session_id = session_id
    return session_id