Lung-Cancer-AI-Advisor / api /middleware.py
moazx's picture
Enhance API security and functionality by adding authentication middleware and session management. Updated app.py to include the new auth router and integrated authentication checks for protected endpoints. Modified requirements.txt to include necessary libraries for session handling. Updated .env.example to include authentication credentials. Improved retrieval functions with query expansion for better medical term matching and enriched context in responses.
ddc9c77
raw
history blame
4.7 kB
"""
Middleware for Medical RAG AI Advisor API
"""
import logging
import os
import time
from typing import Callable, Awaitable, Optional

from fastapi import Request, Response, HTTPException, Cookie
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
# Module-level logger shared by the middleware classes in this file.
logger = logging.getLogger(__name__)
class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware that adds the request handling time to the response headers.

    The ``X-Process-Time`` header holds the number of seconds spent awaiting
    the downstream ``call_next``, formatted with four decimal places.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Time the downstream handler and stamp the duration on the response.

        Args:
            request: The incoming request, passed through unchanged.
            call_next: The next ASGI handler in the middleware chain.

        Returns:
            The downstream response with ``X-Process-Time`` set.
        """
        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock, which can jump (NTP, manual changes) and yield bogus
        # or even negative durations.
        start = time.perf_counter()
        response = await call_next(request)
        response.headers["X-Process-Time"] = f"{time.perf_counter() - start:.4f}"
        return response
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs each request, its response status and its duration.

    Exceptions raised by downstream handlers are logged with their traceback
    and then re-raised unchanged.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Log the request, await the downstream handler, then log the outcome.

        Args:
            request: The incoming request.
            call_next: The next ASGI handler in the middleware chain.

        Returns:
            The downstream response, unmodified.

        Raises:
            Any exception raised by ``call_next``, after it has been logged.
        """
        # Monotonic clock: immune to wall-clock adjustments during the request.
        start = time.perf_counter()
        # Lazy %-style arguments: formatting is skipped when INFO is disabled.
        logger.info("Request: %s %s", request.method, request.url)
        try:
            response = await call_next(request)
        except Exception as exc:
            # logger.exception logs at ERROR level and keeps the traceback,
            # which a plain logger.error(str(e)) would discard.
            logger.exception(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                time.perf_counter() - start,
                request.url.path,
            )
            raise
        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            time.perf_counter() - start,
            request.url.path,
        )
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory sliding-window rate limiter keyed by client IP.

    Each client may make at most ``calls_per_minute`` accepted requests within
    any ``WINDOW_SECONDS`` window. Requests over the limit receive a 429 JSON
    response of the form ``{"detail": "..."}``. Rejected requests are not
    counted against the client.

    NOTE(review): state lives in this instance's memory, so each worker
    process keeps its own independent counters.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60.0

    def __init__(self, app, calls_per_minute: int = 60):
        """Create the limiter.

        Args:
            app: The wrapped ASGI application.
            calls_per_minute: Maximum accepted requests per client per window.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client key -> time.monotonic() timestamps of accepted calls that are
        # still inside the window.
        self.client_calls = {}

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Reject the request with 429 if the client is over its limit.

        Args:
            request: The incoming request.
            call_next: The next ASGI handler in the middleware chain.

        Returns:
            A 429 ``JSONResponse`` when the limit is exceeded, otherwise the
            downstream response.
        """
        # request.client may be None (e.g. some test clients / transports).
        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        cutoff = now - self.WINDOW_SECONDS

        # Trim every client's history to the current window and forget clients
        # with no recent calls, so memory stays bounded.
        pruned = {}
        for ip, calls in self.client_calls.items():
            recent = [call_time for call_time in calls if call_time > cutoff]
            if recent:
                pruned[ip] = recent
        self.client_calls = pruned

        recent_calls = self.client_calls.get(client_ip, [])
        if len(recent_calls) >= self.calls_per_minute:
            # Return a response instead of raising HTTPException: exceptions
            # raised in a BaseHTTPMiddleware bypass FastAPI's exception
            # handlers and would surface as a 500 rather than a 429.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        recent_calls.append(now)
        self.client_calls[client_ip] = recent_calls
        return await call_next(request)
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware that requires a valid session cookie on non-public paths.

    Requests to paths listed in ``PUBLIC_PATHS`` pass through untouched. Any
    other request must carry a ``session_token`` cookie that
    ``api.routers.auth.verify_session`` accepts. Otherwise a 401 JSON
    response ``{"detail": "..."}`` is returned. On success the session's
    ``username`` is stored on ``request.state.user``.
    """

    # Paths that don't require authentication. "/" matches only the root
    # itself. Every other entry matches itself and its sub-paths, e.g.
    # "/docs" also covers "/docs/oauth2-redirect".
    # NOTE(review): if the app serves other unauthenticated resources
    # (e.g. static frontend assets), they must be listed here.
    PUBLIC_PATHS = [
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/auth/login",
        "/auth/status",
    ]

    @classmethod
    def _is_public_path(cls, path: str) -> bool:
        """Return True if *path* is, or is beneath, an entry of PUBLIC_PATHS."""
        for public_path in cls.PUBLIC_PATHS:
            if path == public_path:
                return True
            # Prefix matching only on a path-segment boundary, and never for
            # "/": every URL path starts with "/", so a plain
            # startswith("/") check would make the entire API public.
            prefix = public_path.rstrip("/")
            if prefix and path.startswith(prefix + "/"):
                return True
        return False

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Let public paths through; otherwise validate the session cookie.

        Args:
            request: The incoming request.
            call_next: The next ASGI handler in the middleware chain.

        Returns:
            A 401 ``JSONResponse`` when the session is missing or invalid,
            otherwise the downstream response.
        """
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # Responses are returned rather than HTTPException raised: exceptions
        # raised in a BaseHTTPMiddleware bypass FastAPI's exception handlers
        # and would surface to the client as a 500 instead of a 401.
        session_token = request.cookies.get("session_token")
        if not session_token:
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # Imported at call time -- presumably to avoid a circular import with
        # the auth router; confirm before moving to module level.
        from api.routers.auth import verify_session

        session_data = verify_session(session_token)
        if not session_data:
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired session"},
            )

        # Expose the authenticated user to downstream handlers.
        request.state.user = session_data.get("username")
        return await call_next(request)
def get_cors_middleware_config(allow_origins: Optional[list] = None) -> dict:
    """Return keyword arguments for ``CORSMiddleware``.

    Args:
        allow_origins: Explicit list of allowed origins. When omitted, the
            comma-separated ``CORS_ALLOW_ORIGINS`` environment variable is
            used. If that is unset or blank, all origins (``"*"``) are
            allowed, which was the previous hard-coded behaviour.

    Returns:
        A dict with ``allow_origins``, ``allow_credentials``,
        ``allow_methods`` and ``allow_headers`` keys.

    NOTE(review): with ``"*"`` and ``allow_credentials=True``, Starlette's
    CORSMiddleware echoes the caller's Origin on cookie-bearing requests. That
    lets any site make credentialed calls, so set ``CORS_ALLOW_ORIGINS`` in
    production (confirm against the installed Starlette version).
    """
    if allow_origins is None:
        raw_origins = os.getenv("CORS_ALLOW_ORIGINS", "")
        configured = [origin.strip() for origin in raw_origins.split(",") if origin.strip()]
        allow_origins = configured or ["*"]
    return {
        "allow_origins": list(allow_origins),
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }