jebin2's picture
Phase 2: Implement Audit Middleware
1c302c7
"""
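Audit Middleware - automatic request audit logging.
Logs API requests (method, path, query parameters, response status code and
duration) to the AuditLog table, skipping paths excluded by AuditServiceConfig.
Response bodies are not recorded. Follows the same pattern as CreditMiddleware.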
"""
import time
import logging
from typing import Optional
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from core.database import async_session_maker
from services.audit_service.config import AuditServiceConfig
# Module-level logger, named after this module; audit failures are reported
# here instead of being raised to the client.
logger: logging.Logger = logging.getLogger(name=__name__)
class AuditMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic audit logging.

    Every request is written to the audit log unless its path is excluded by
    ``AuditServiceConfig.is_excluded``, it is a CORS preflight (``OPTIONS``),
    or ``AuditServiceConfig.should_log`` rejects the (path, status code) pair.
    Audit failures are logged and never propagate to the client.
    """

    def __init__(self, app: ASGIApp):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Process the request and record an audit entry for it.

        Flow:
            1. Skip excluded paths and OPTIONS (CORS preflight) requests.
            2. Run the downstream app, timing it with a monotonic clock.
            3. Resolve the authenticated user (if any) from ``request.state``.
            4. Write the audit entry if ``should_log`` allows it.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware stack.

        Returns:
            The downstream response, unchanged.
        """
        path = request.url.path

        # Excluded paths and CORS preflight requests pass through unaudited.
        if AuditServiceConfig.is_excluded(path) or request.method == "OPTIONS":
            return await call_next(request)

        # perf_counter is monotonic, so the measured duration is immune to
        # wall-clock adjustments that would skew time.time().
        start = time.perf_counter()
        response = await call_next(request)
        duration_ms = (time.perf_counter() - start) * 1000

        # Read the user only after the downstream app has run: request.state
        # is backed by the shared ASGI scope, so a user attached by an inner
        # middleware (e.g. AuthMiddleware) becomes visible here as well as one
        # attached by an outer middleware. getattr keeps a user object without
        # an ``id`` from failing the request for the sake of auditing.
        user = getattr(request.state, "user", None)
        user_id = getattr(user, "id", None)

        if AuditServiceConfig.should_log(path, response.status_code):
            # The write is awaited before the response is returned, so it adds
            # to request latency. Failures are swallowed on purpose: auditing
            # must never break the request itself.
            try:
                await self._log_request(
                    request=request,
                    response=response,
                    user_id=user_id,
                    duration_ms=duration_ms,
                )
            except Exception as exc:
                logger.exception("Failed to log request: %s", exc)

        return response

    async def _log_request(
        self,
        request: Request,
        response: Response,
        user_id: Optional[int],
        duration_ms: float,
    ) -> None:
        """
        Persist one audit entry for a completed request.

        Opens a dedicated database session via ``async_session_maker`` and
        commits the entry. Failures during logging or commit are logged and
        rolled back rather than raised.

        Args:
            request: The request being audited.
            response: The response produced for it. Only its status code is
                recorded; response bodies are not captured.
            user_id: ID of the authenticated user, or None for anonymous calls.
            duration_ms: Time spent in the downstream app, in milliseconds.
        """
        # Imported lazily; presumably to avoid a circular import with the
        # audit service package (TODO confirm).
        from services.audit_service import AuditService

        path = request.url.path
        status_code = response.status_code
        log_type = AuditServiceConfig.get_log_type(path)

        details = {
            "method": request.method,
            "path": path,
            "query_params": dict(request.query_params),
            "status_code": status_code,
            "duration_ms": round(duration_ms, 2),
        }
        # NOTE: response bodies are intentionally not recorded. Capturing them
        # requires buffering the streamed body and is a privacy risk; if a
        # response-body option is ever supported, it belongs here.

        async with async_session_maker() as db:
            try:
                await AuditService.log_event(
                    db=db,
                    action=f"{request.method}:{path}",
                    status="success" if status_code < 400 else "failure",
                    user_id=user_id,
                    client_user_id=None,
                    details=details,
                    request=request,
                    log_type=log_type,
                )
                await db.commit()
            except Exception as exc:
                logger.exception("Failed to commit audit log: %s", exc)
                await db.rollback()