# Agentic-RagBot / src/middlewares.py
# T0X1N's picture
# chore: codebase audit and fixes (ruff, mypy, pytest)
# 9659593
"""
MediGuard AI — Production Middlewares
HIPAA-aware audit logging, request timing, and security headers.
Designed for medical applications requiring compliance patterns.
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
import uuid
from collections.abc import Callable
from datetime import UTC, datetime
from typing import Any
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
# Module-level logger under the "mediguard.audit" name; the audit middleware
# in this file emits its AUDIT_* records through it.
logger = logging.getLogger("mediguard.audit")
# ---------------------------------------------------------------------------
# HIPAA Audit Logger
# ---------------------------------------------------------------------------
# Request-body keys whose values must NEVER be written to the logs.
# Keys are compared in lower case, so every entry here is lower-case.
SENSITIVE_FIELDS = {
    # Clinical payload
    "biomarkers",
    "patient_context",
    # Record identifiers
    "patient_id",
    "mrn",
    "ssn",
    # Personal / contact details
    "name",
    "address",
    "phone",
    "email",
    # Demographics
    "age",
    "gender",
    "bmi",
    "dob",
    "date_of_birth",
}
# Path fragments of the endpoints for which audit records are written.
AUDITABLE_ENDPOINTS = {
    # Analysis endpoints
    "/analyze/natural",
    "/analyze/structured",
    # Question answering (plain and streaming)
    "/ask",
    "/ask/stream",
    # Retrieval
    "/search",
}
def _hash_sensitive(value: str) -> str:
"""Create a one-way hash of sensitive data for audit trail without logging PHI."""
return f"sha256:{hashlib.sha256(value.encode()).hexdigest()[:16]}"
def _redact_body(body_dict: dict) -> dict:
    """Return a copy of *body_dict* with sensitive values replaced by placeholders.

    A key is treated as sensitive when it is a string whose lower-cased form
    is a member of ``SENSITIVE_FIELDS``. Its value is replaced with:

    - ``"[REDACTED: N fields]"`` when the value is a dict,
    - ``"[REDACTED: N chars]"`` when the value is a string,
    - ``"[REDACTED]"`` for anything else.

    Values under non-sensitive keys are walked recursively (through nested
    dicts and lists), so a sensitive key buried inside an otherwise harmless
    wrapper object is redacted as well instead of being copied through.

    The input mapping is never modified; the set of top-level keys in the
    result is identical to the input's.
    """

    def _placeholder(value: object) -> str:
        # Keep a size hint for the audit trail without exposing any content.
        if isinstance(value, dict):
            return f"[REDACTED: {len(value)} fields]"
        if isinstance(value, str):
            return f"[REDACTED: {len(value)} chars]"
        return "[REDACTED]"

    def _scrub(value: object) -> object:
        # Descend into containers; scalars under non-sensitive keys pass through.
        if isinstance(value, dict):
            return _redact_body(value)
        if isinstance(value, list):
            return [_scrub(item) for item in value]
        return value

    redacted: dict = {}
    for key, value in body_dict.items():
        # Non-string keys cannot name a sensitive field, and have no .lower().
        if isinstance(key, str) and key.lower() in SENSITIVE_FIELDS:
            redacted[key] = _placeholder(value)
        else:
            redacted[key] = _scrub(value)
    return redacted
class HIPAAAuditMiddleware(BaseHTTPMiddleware):
    """
    HIPAA-compliant audit logging middleware.

    Features:
    - Generates unique request IDs for traceability
    - Logs request metadata WITHOUT PHI/biomarker values
    - Creates audit trail for all medical analysis requests
    - Tracks request timing and response status
    - Hashes sensitive identifiers for correlation

    Audit logs are structured JSON for easy SIEM integration. Three record
    types are emitted for auditable paths: ``AUDIT_REQUEST`` before the
    request is handled, then either ``AUDIT_COMPLETE`` (a response was
    produced) or ``AUDIT_FAILED`` (the downstream handler raised).
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Audit one request/response cycle.

        Args:
            request: Incoming request. ``request.state.request_id`` is set so
                downstream code can reference the same ID.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response with ``X-Request-ID`` and
            ``X-Response-Time`` headers added.

        Raises:
            Exception: Anything raised by ``call_next`` is re-raised unchanged
                after an ``AUDIT_FAILED`` record has been written.
        """
        # Generate request ID
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        request.state.request_id = request_id

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        start_time = time.perf_counter()

        path = request.url.path
        method = request.method

        # Check if this endpoint needs audit logging
        needs_audit = any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS)

        if needs_audit:
            client_ip = request.client.host if request.client else "unknown"
            user_agent = request.headers.get("user-agent", "unknown")[:100]

            # Pre-request audit entry; identifiers are hashed, never logged raw.
            audit_entry: dict[str, Any] = {
                "event": "request_start",
                "timestamp": datetime.now(UTC).isoformat(),
                "request_id": request_id,
                "method": method,
                "path": path,
                "client_ip_hash": _hash_sensitive(client_ip),
                "user_agent_hash": _hash_sensitive(user_agent),
            }

            # Describe the POST body by field names only (without logging PHI).
            if method == "POST":
                try:
                    body = await request.body()
                    # Store body for re-reading by route handlers
                    request._body = body
                    if body:
                        body_dict = json.loads(body)
                        if isinstance(body_dict, dict):
                            redacted = _redact_body(body_dict)
                            audit_entry["request_fields"] = list(redacted.keys())
                            # Log presence of biomarkers without values
                            if "biomarkers" in body_dict:
                                biomarkers = body_dict["biomarkers"]
                                audit_entry["biomarker_count"] = (
                                    len(biomarkers) if isinstance(biomarkers, dict) else 1
                                )
                        else:
                            # Valid JSON but not an object: there are no field
                            # names to record.
                            logger.debug(
                                "Failed to audit POST body: top-level JSON value is %s, not an object",
                                type(body_dict).__name__,
                            )
                except Exception as exc:
                    # Deliberately best-effort: a body that cannot be read or
                    # parsed must not prevent the request from being served.
                    logger.debug("Failed to audit POST body: %s", exc)

            logger.info("AUDIT_REQUEST: %s", json.dumps(audit_entry))

        # Process request
        try:
            response: Response = await call_next(request)
        except Exception as exc:
            if needs_audit:
                failure_entry = {
                    "event": "request_failed",
                    "timestamp": datetime.now(UTC).isoformat(),
                    "request_id": request_id,
                    "method": method,
                    "path": path,
                    # Class name only: the exception message could contain PHI.
                    "error_type": type(exc).__name__,
                    "elapsed_ms": round((time.perf_counter() - start_time) * 1000, 2),
                }
                logger.error("AUDIT_FAILED: %s", json.dumps(failure_entry))
            raise

        # Post-request audit
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        if needs_audit:
            completion_entry = {
                "event": "request_complete",
                "timestamp": datetime.now(UTC).isoformat(),
                "request_id": request_id,
                "method": method,
                "path": path,
                "status_code": response.status_code,
                "elapsed_ms": round(elapsed_ms, 2),
            }
            logger.info("AUDIT_COMPLETE: %s", json.dumps(completion_entry))

        # Add request ID and timing to response headers (for every path)
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Response-Time"] = f"{elapsed_ms:.2f}ms"
        return response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Stamp a fixed set of security headers onto every response (HIPAA compliance).
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the downstream handler, then overwrite the security headers on its response."""
        response: Response = await call_next(request)

        # Baseline headers, written in this order on every response.
        baseline_headers = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Cache-Control": "no-store, no-cache, must-revalidate",
            "Pragma": "no-cache",
        }
        for header_name, header_value in baseline_headers.items():
            response.headers[header_name] = header_value

        # Medical data should never be cached: when the path contains any
        # auditable endpoint fragment (substring match), replace Cache-Control.
        url_path = request.url.path
        for endpoint in AUDITABLE_ENDPOINTS:
            if endpoint in url_path:
                response.headers["Cache-Control"] = "no-store, private"
                break

        return response