|
|
""" |
|
|
Middleware for Lung Cancer AI Advisor API |
|
|
""" |
|
|
import time |
|
|
import logging |
|
|
from typing import Callable, Awaitable, Optional |
|
|
from fastapi import Request, Response, HTTPException, Cookie |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
|
|
|
|
# Module-level logger named after this module; used by LoggingMiddleware.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware that reports request handling time in a response header.

    Adds an ``X-Process-Time`` header holding the number of seconds (four
    decimal places) spent producing the response downstream of this
    middleware.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Time the downstream handler and attach ``X-Process-Time``.

        Args:
            request: The incoming request.
            call_next: Invokes the next middleware / route handler.

        Returns:
            The downstream response with the ``X-Process-Time`` header set.
        """
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump (NTP sync, manual changes) and give
        # negative or wildly wrong durations.
        start = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - start
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response
|
|
|
|
|
|
|
|
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs every request and the outcome of handling it.

    Emits an INFO record when a request arrives and another once a response
    is produced (status code, elapsed seconds, path). If the downstream
    handler raises, an ERROR record including the traceback is written and
    the exception is re-raised unchanged.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Log the request, delegate to ``call_next`` and log the result.

        Args:
            request: The incoming request.
            call_next: Invokes the next middleware / route handler.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Anything raised downstream is logged, then re-raised.
        """
        # Monotonic clock: elapsed times cannot go negative on clock changes.
        start = time.perf_counter()
        # Lazy %-style arguments: formatting is skipped when INFO is disabled.
        logger.info("Request: %s %s", request.method, request.url)

        try:
            response = await call_next(request)
        except Exception as exc:
            elapsed = time.perf_counter() - start
            # logger.exception logs at ERROR level and appends the traceback,
            # which a str(exc)-only message would discard.
            logger.exception(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                elapsed,
                request.url.path,
            )
            raise

        elapsed = time.perf_counter() - start
        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            elapsed,
            request.url.path,
        )
        return response
|
|
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-client-IP sliding-window rate limiter.

    Each client may make at most ``calls_per_minute`` accepted requests
    within any trailing ``window_seconds`` window (60 s by default). Excess
    requests receive a ``429`` JSON response with a ``Retry-After`` header.
    State is held in this instance's memory only, so limits are not shared
    across worker processes.
    """

    def __init__(self, app, calls_per_minute: int = 60, window_seconds: float = 60.0):
        """
        Args:
            app: The wrapped ASGI application.
            calls_per_minute: Maximum accepted requests per client per window.
            window_seconds: Length of the sliding window, in seconds.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        self.window_seconds = window_seconds
        # client IP -> ascending time.monotonic() stamps of accepted requests
        self.client_calls = {}

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Forward the request, or answer 429 if the client is over its limit.

        Args:
            request: The incoming request.
            call_next: Invokes the next middleware / route handler.

        Returns:
            The downstream response, or a 429 JSON response when limited.
        """
        # request.client is None when the server does not report the peer
        # address (e.g. some test clients); don't crash on it.
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: immune to wall-clock jumps, and keeps every
        # per-client timestamp list in ascending order.
        now = time.monotonic()
        cutoff = now - self.window_seconds

        # Forget clients with no call inside the window so the table does
        # not grow without bound. Lists are ascending, so [-1] is newest.
        self.client_calls = {
            ip: calls
            for ip, calls in self.client_calls.items()
            if calls and calls[-1] > cutoff
        }

        recent_calls = [t for t in self.client_calls.get(client_ip, []) if t > cutoff]

        if len(recent_calls) >= self.calls_per_minute:
            # Raising HTTPException from a BaseHTTPMiddleware bypasses
            # FastAPI's exception handlers (they sit inside the middleware
            # stack) and surfaces as a 500, so build the 429 here instead,
            # with the same {"detail": ...} body FastAPI would produce.
            from fastapi.responses import JSONResponse

            # Seconds until the oldest counted call leaves the window.
            oldest = recent_calls[0] if recent_calls else now
            retry_after = int(oldest + self.window_seconds - now) + 1
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        recent_calls.append(now)
        self.client_calls[client_ip] = recent_calls
        return await call_next(request)
|
|
|
|
|
|
|
|
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware that requires a valid session cookie on non-public paths.

    Requests to a path listed in ``PUBLIC_PATHS`` (exact match, or a sub-path
    of any entry other than ``/``) and CORS preflight requests pass through
    untouched. Every other request must carry a ``session_token`` cookie that
    ``api.routers.auth.verify_session`` accepts; otherwise a 401 JSON
    response is returned. On success the session's ``username`` is stored on
    ``request.state.user``.

    NOTE(review): if the app serves static/frontend assets that must load
    before login, their paths need to be added to ``PUBLIC_PATHS``.
    """

    # Paths reachable without a session. "/" is matched exactly; every other
    # entry also covers its sub-paths (e.g. "/docs/oauth2-redirect").
    PUBLIC_PATHS = [
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/auth/login",
        "/auth/status",
    ]

    @classmethod
    def _is_public(cls, path: str) -> bool:
        """Return True if *path* is exempt from authentication."""
        for public_path in cls.PUBLIC_PATHS:
            if path == public_path:
                return True
            # Prefix match on a path-segment boundary only. A bare
            # startswith() made *every* path public via "/", and would also
            # expose e.g. "/healthcheck-admin" via "/health".
            if public_path != "/" and path.startswith(public_path.rstrip("/") + "/"):
                return True
        return False

    @staticmethod
    def _unauthorized(detail: str) -> Response:
        """Build the 401 response sent to unauthenticated requests."""
        # HTTPException raised inside a BaseHTTPMiddleware is not routed to
        # FastAPI's exception handlers and becomes a 500, so respond directly
        # with the same {"detail": ...} body FastAPI would have produced.
        from fastapi.responses import JSONResponse

        return JSONResponse(status_code=401, content={"detail": detail})

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Authenticate the request, or short-circuit with a 401.

        Args:
            request: The incoming request.
            call_next: Invokes the next middleware / route handler.

        Returns:
            The downstream response, or a 401 JSON response.
        """
        path = request.url.path

        # Browsers send CORS preflights without cookies, so rejecting them
        # here would block every cross-origin call to protected routes.
        is_preflight = (
            request.method == "OPTIONS"
            and "access-control-request-method" in request.headers
        )
        if is_preflight or self._is_public(path):
            return await call_next(request)

        session_token = request.cookies.get("session_token")
        if not session_token:
            return self._unauthorized("Authentication required")

        # Imported lazily, presumably to avoid a circular import with the
        # auth router — confirm before hoisting to module level.
        from api.routers.auth import verify_session

        session_data = verify_session(session_token)
        if not session_data:
            return self._unauthorized("Invalid or expired session")

        # Expose the authenticated user to downstream handlers.
        request.state.user = session_data.get("username")
        return await call_next(request)
|
|
|
|
|
|
|
|
def get_cors_middleware_config():
    """Build the keyword arguments for FastAPI/Starlette ``CORSMiddleware``.

    Allowed origins come from the comma-separated ``ALLOWED_ORIGINS``
    environment variable. Whitespace around each entry is stripped and
    empty entries (e.g. from a trailing comma) are ignored. If no usable
    origin remains, a built-in list is used: the deployed Space plus common
    local development servers.

    Returns:
        dict: Keys ``allow_origins``, ``allow_credentials``, ``allow_methods``,
        ``allow_headers`` and ``expose_headers``.
    """
    import os

    raw_origins = os.getenv("ALLOWED_ORIGINS", "")
    # Strip so "a, b" works; drop blanks left by stray or trailing commas.
    allowed_origins = [
        origin.strip() for origin in raw_origins.split(",") if origin.strip()
    ]

    if not allowed_origins:
        allowed_origins = [
            "https://moazx-api.hf.space",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://localhost:5500",
            "http://127.0.0.1:5500",
            "http://localhost:3000",
            "http://127.0.0.1:3000",
            # NOTE(review): "null" is the Origin sent by file:// pages and
            # sandboxed iframes. Together with allow_credentials=True it lets
            # any such document make credentialed requests; keep only if it
            # is really needed for local testing.
            "null",
        ]

    return {
        "allow_origins": allowed_origins,
        "allow_credentials": True,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With", "Cookie"],
        "expose_headers": ["Set-Cookie"],
    }
|
|
|