Spaces:
Sleeping
Sleeping
| """ | |
| Middleware for Medical RAG AI Advisor API | |
| """ | |
import logging
import time
from typing import Callable, Awaitable

from fastapi import Request, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
# Module-level logger named after this module; this file attaches no
# handlers itself, so output depends on how the application configures logging.
logger: logging.Logger = logging.getLogger(__name__)
class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware to add processing time to response headers.

    Sets ``X-Process-Time`` to the number of seconds (four decimal places)
    between entering this middleware and ``call_next`` returning a response.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock, which can be adjusted mid-request (NTP, manual
        # changes) and yield skewed or even negative durations.
        start = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - start
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for request/response logging.

    Logs the incoming method and URL, then either the response status and
    elapsed time, or the exception (with traceback) and elapsed time. An
    exception is re-raised unchanged after being logged.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        # Monotonic clock: durations are immune to wall-clock adjustments.
        start = time.perf_counter()
        # NOTE(review): request.url includes the query string; if clients put
        # sensitive data in query parameters it will be written to the logs.
        logger.info("Request: %s %s", request.method, request.url)
        # Only the downstream call belongs in the try: a failure in our own
        # logging below must not be reported as a request error.
        try:
            response = await call_next(request)
        except Exception as exc:
            # logger.exception keeps the traceback, which logger.error
            # without exc_info dropped.
            logger.exception(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                time.perf_counter() - start,
                request.url.path,
            )
            raise
        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            time.perf_counter() - start,
            request.url.path,
        )
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple per-client sliding-window rate limiting middleware.

    A client, keyed by ``request.client.host``, may make at most
    ``calls_per_minute`` accepted requests within any rolling
    ``WINDOW_SECONDS`` window. Requests over the limit receive a ``429``
    JSON response ``{"detail": ...}`` with a ``Retry-After`` header and never
    reach the application.

    NOTE(review): state lives in this process's memory, so each worker
    process enforces its own limit. Behind a reverse proxy all requests may
    share the proxy's address unless forwarded headers are trusted.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60.0

    def __init__(self, app, calls_per_minute: int = 60):
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client host -> ascending time.monotonic() stamps of accepted calls
        self.client_calls = {}
        self._last_sweep = time.monotonic()

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        # Starlette types request.client as optional; bucket unknown peers
        # together rather than raising AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        # monotonic() is unaffected by wall-clock jumps.
        now = time.monotonic()
        cutoff = now - self.WINDOW_SECONDS

        # Evict idle clients at most once per window instead of rebuilding
        # the whole table on every request.
        if now - self._last_sweep >= self.WINDOW_SECONDS:
            self.client_calls = {
                ip: calls
                for ip, calls in self.client_calls.items()
                if calls and calls[-1] > cutoff
            }
            self._last_sweep = now

        recent_calls = [
            stamp for stamp in self.client_calls.get(client_ip, ()) if stamp > cutoff
        ]
        if len(recent_calls) >= self.calls_per_minute:
            self.client_calls[client_ip] = recent_calls
            # The oldest recent call is the first to leave the window.
            retry_after = (
                int(recent_calls[0] - cutoff) + 1
                if recent_calls
                else int(self.WINDOW_SECONDS)
            )
            # Raising HTTPException here would bypass FastAPI's exception
            # handlers (they run inside user middleware) and surface as a 500,
            # so build the 429 response directly.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        recent_calls.append(now)
        self.client_calls[client_ip] = recent_calls
        return await call_next(request)
def get_cors_middleware_config(allow_origins=None) -> dict:
    """Get CORS middleware configuration.

    Args:
        allow_origins: Allowed origin URLs, as a single string or an iterable
            of strings. ``None`` (the default) keeps the permissive ``["*"]``.
            Pass explicit origins in production.

    Returns:
        A fresh dict of ``CORSMiddleware`` keyword arguments, e.g. for
        ``app.add_middleware(CORSMiddleware, **config)``.

    Credentials are enabled only when the origins are explicit. Browsers
    refuse ``Access-Control-Allow-Origin: *`` on credentialed requests, and
    Starlette's CORSMiddleware compensates by echoing the caller's ``Origin``.
    A wildcard combined with credentials would therefore let any website make
    authenticated requests to this API.
    """
    if allow_origins is None:
        origins = ["*"]
    elif isinstance(allow_origins, str):
        # Avoid list("https://...") splitting one origin into characters.
        origins = [allow_origins]
    else:
        origins = list(allow_origins)
    return {
        "allow_origins": origins,
        "allow_credentials": "*" not in origins,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }