Spaces:
Running
Running
Update .gitignore to include frontend and Lung Cancer Guidelines directories. Enhance CORS middleware configuration to allow origins from environment variables, with defaults for local development. Modify session cookie settings to require secure flag for SameSite=None, improving security for cross-site requests.
0f194af
| """ | |
| Middleware for Medical RAG AI Advisor API | |
| """ | |
| import time | |
| import logging | |
| from typing import Callable, Awaitable, Optional | |
| from fastapi import Request, Response, HTTPException, Cookie | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| logger = logging.getLogger(__name__) | |
class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware that reports request processing time in a response header.

    The elapsed time, in seconds formatted with four decimal places, is
    written to the ``X-Process-Time`` header of the response returned by the
    rest of the stack.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Time the downstream handling of ``request``.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that passes the request on and
                returns the response.

        Returns:
            The downstream response with ``X-Process-Time`` set.
        """
        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump if the system wall clock is adjusted mid-request.
        started = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - started
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs every request and the outcome of handling it.

    One INFO record is emitted when the request arrives and one when the
    response is produced. If the downstream stack raises, an ERROR record
    with the traceback is emitted and the exception is re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request, forward it, and log the response or the failure.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that passes the request on and
                returns the response.

        Returns:
            The response produced by the downstream stack.

        Raises:
            Exception: Whatever ``call_next`` raised, after it has been logged.
        """
        # Monotonic clock: durations stay correct across wall-clock changes.
        started = time.perf_counter()

        # Lazy %-style arguments defer string formatting until a handler
        # actually emits the record.
        logger.info("Request: %s %s", request.method, request.url)

        try:
            response = await call_next(request)
        except Exception as exc:
            elapsed = time.perf_counter() - started
            # logger.exception logs at ERROR level and attaches the
            # traceback, which a plain logger.error call would drop.
            logger.exception(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                elapsed,
                request.url.path,
            )
            raise

        elapsed = time.perf_counter() - started
        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            elapsed,
            request.url.path,
        )
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-client rate limiting middleware.

    Each client (keyed by remote host) may make at most ``calls_per_minute``
    requests within any trailing 60-second window. State lives in a plain
    dict on this instance.
    """

    def __init__(self, app, calls_per_minute: int = 60):
        """Create the limiter.

        Args:
            app: The ASGI application being wrapped.
            calls_per_minute: Maximum accepted requests per client within a
                60-second window.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # Maps client host -> time.time() timestamps of accepted requests.
        self.client_calls = {}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject the request with 429 if the client is over its limit.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that passes the request on and
                returns the response.

        Returns:
            The downstream response, or a 429 JSON response with a
            ``detail`` message when the limit has been reached.
        """
        # Local import keeps this block self-contained; starlette is already
        # a dependency of this module.
        from starlette.responses import JSONResponse

        # request.client can be None (no peer address available); fall back
        # to a shared bucket instead of failing with AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()
        window_start = current_time - 60

        # Forget clients that made no request inside the current window.
        self.client_calls = {
            ip: calls
            for ip, calls in self.client_calls.items()
            if any(call_time > window_start for call_time in calls)
        }

        recent_calls = [
            call_time
            for call_time in self.client_calls.get(client_ip, [])
            if call_time > window_start
        ]

        if len(recent_calls) >= self.calls_per_minute:
            # Return the response directly: an HTTPException raised here is
            # outside the app's exception handlers and would reach the client
            # as a 500 instead of a 429.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        recent_calls.append(current_time)
        self.client_calls[client_ip] = recent_calls
        return await call_next(request)
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware to protect endpoints with session authentication.

    Requests to paths listed in ``PUBLIC_PATHS`` pass through untouched.
    Every other request must carry a ``session_token`` cookie that
    ``verify_session`` accepts; otherwise a 401 JSON response is returned.
    """

    # Paths that don't require authentication. An entry matches the path
    # itself and anything below it ("/docs" also covers "/docs/..."), except
    # "/", which only matches the root path.
    PUBLIC_PATHS = [
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/auth/login",
        "/auth/status",
    ]

    @classmethod
    def _is_public_path(cls, path: str) -> bool:
        """Return True if ``path`` is covered by an entry in PUBLIC_PATHS."""
        for public_path in cls.PUBLIC_PATHS:
            if path == public_path:
                return True
            # A bare startswith() check would make "/" match every path and
            # disable authentication entirely, and would let "/docs" match
            # "/docs-internal". Only match on a path-segment boundary.
            if public_path != "/" and path.startswith(public_path.rstrip("/") + "/"):
                return True
        return False

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Authenticate the request unless it targets a public path.

        On success the session's ``username`` value is stored on
        ``request.state.user`` before the request is forwarded.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that passes the request on and
                returns the response.

        Returns:
            The downstream response, or a 401 JSON response with a
            ``detail`` message when authentication fails.
        """
        # Local import keeps this block self-contained; starlette is already
        # a dependency of this module.
        from starlette.responses import JSONResponse

        if self._is_public_path(request.url.path):
            return await call_next(request)

        session_token = request.cookies.get("session_token")
        if not session_token:
            # Returned rather than raised: an HTTPException raised inside
            # this middleware is outside the app's exception handlers and
            # would reach the client as a 500.
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # Imported at call time, as in the original code (presumably to
        # avoid a circular import with the auth router — confirm).
        from api.routers.auth import verify_session

        session_data = verify_session(session_token)
        if not session_data:
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired session"},
            )

        # Make the authenticated user available to downstream handlers.
        request.state.user = session_data.get("username")
        return await call_next(request)
def get_cors_middleware_config():
    """Get CORS middleware configuration.

    Allowed origins are read from the ``ALLOWED_ORIGINS`` environment
    variable as a comma-separated list. Whitespace around each entry is
    ignored and empty entries are dropped. When the variable is unset or
    contains no usable entry, a default list (the Hugging Face Space plus
    local development hosts) is used.

    Returns:
        dict: Keyword arguments for the CORS middleware: ``allow_origins``,
        ``allow_credentials``, ``allow_methods``, ``allow_headers`` and
        ``expose_headers``.
    """
    import os

    raw_origins = os.getenv("ALLOWED_ORIGINS", "")
    # Strip each entry and skip blanks so values such as "a, b" or "a,"
    # do not yield origins like " b" or "" that can never match a request.
    allowed_origins = [
        origin.strip() for origin in raw_origins.split(",") if origin.strip()
    ]

    if not allowed_origins:
        # Default to allowing Hugging Face Space and localhost
        allowed_origins = [
            "https://moazx-api.hf.space",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
        ]

    # NOTE(review): with credentials enabled, browsers may treat a "*" in
    # expose_headers as a literal header name rather than a wildcard —
    # confirm the frontend can still read the headers it needs.
    return {
        "allow_origins": allowed_origins,
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "expose_headers": ["*"],
    }