Spaces:
Running
Running
| """ | |
| MediGuard AI — Production Middlewares | |
| HIPAA-aware audit logging, request timing, and security headers. | |
| Designed for medical applications requiring compliance patterns. | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import logging | |
| import time | |
| import uuid | |
| from collections.abc import Callable | |
| from datetime import UTC, datetime | |
| from typing import Any | |
| from fastapi import Request, Response | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
logger = logging.getLogger("mediguard.audit")

# ---------------------------------------------------------------------------
# HIPAA Audit Logger
# ---------------------------------------------------------------------------

# Request-body keys whose values must NEVER be logged. Incoming keys are
# lower-cased before being looked up here, so every entry is lower-case.
SENSITIVE_FIELDS = set(
    (
        "biomarkers",
        "patient_context",
        "patient_id",
        "age",
        "gender",
        "bmi",
        "ssn",
        "mrn",
        "name",
        "address",
        "phone",
        "email",
        "dob",
        "date_of_birth",
    )
)

# Endpoint paths that handle medical data: requests to them are audit-logged
# and their responses get the strictest cache headers.
AUDITABLE_ENDPOINTS = set(
    (
        "/analyze/natural",
        "/analyze/structured",
        "/ask",
        "/ask/stream",
        "/search",
    )
)
| def _hash_sensitive(value: str) -> str: | |
| """Create a one-way hash of sensitive data for audit trail without logging PHI.""" | |
| return f"sha256:{hashlib.sha256(value.encode()).hexdigest()[:16]}" | |
| def _redact_body(body_dict: dict) -> dict: | |
| """Redact sensitive fields from request body for logging.""" | |
| redacted = {} | |
| for key, value in body_dict.items(): | |
| if key.lower() in SENSITIVE_FIELDS: | |
| if isinstance(value, dict): | |
| redacted[key] = f"[REDACTED: {len(value)} fields]" | |
| elif isinstance(value, str): | |
| redacted[key] = f"[REDACTED: {len(value)} chars]" | |
| else: | |
| redacted[key] = "[REDACTED]" | |
| else: | |
| redacted[key] = value | |
| return redacted | |
class HIPAAAuditMiddleware(BaseHTTPMiddleware):
    """
    HIPAA-compliant audit logging middleware.

    Features:
    - Generates unique request IDs for traceability
    - Logs request metadata WITHOUT PHI/biomarker values
    - Creates audit trail for all medical analysis requests, including
      requests whose handler raises
    - Tracks request timing and response status
    - Hashes sensitive identifiers for correlation

    Audit logs are structured JSON for easy SIEM integration.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Audit one request and tag its response.

        Every request gets a fresh request ID, which is stored on
        ``request.state.request_id`` and echoed as ``X-Request-ID``, plus an
        ``X-Response-Time`` header. Paths starting with an entry of
        ``AUDITABLE_ENDPOINTS`` also log an ``AUDIT_REQUEST`` line before
        the handler runs and an ``AUDIT_COMPLETE`` line after it.

        If ``call_next`` raises, an ``AUDIT_COMPLETE`` line is still written
        with ``status_code`` 500 and the exception class name, and the
        exception is then re-raised unchanged.
        """
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        request.state.request_id = request_id
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.perf_counter()

        path = request.url.path
        method = request.method
        needs_audit = any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS)

        if needs_audit:
            audit_entry = self._request_entry(request, request_id)
            if method == "POST":
                await self._summarize_body(request, audit_entry)
            logger.info("AUDIT_REQUEST: %s", json.dumps(audit_entry))

        try:
            response: Response = await call_next(request)
        except Exception as exc:
            # Without this, a failing handler would leave a request_start
            # record with no matching completion in the audit trail.
            # Unhandled errors presumably reach the client as a 500 via the
            # server's error middleware, so that status is recorded.
            if needs_audit:
                self._log_completion(
                    request_id,
                    method,
                    path,
                    500,
                    (time.perf_counter() - start_time) * 1000,
                    error=type(exc).__name__,
                )
            raise

        # NOTE: for streaming responses this is the time until the response
        # object is returned, not until its body has finished streaming.
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        if needs_audit:
            self._log_completion(request_id, method, path, response.status_code, elapsed_ms)

        response.headers["X-Request-ID"] = request_id
        response.headers["X-Response-Time"] = f"{elapsed_ms:.2f}ms"
        return response

    @staticmethod
    def _request_entry(request: Request, request_id: str) -> dict[str, Any]:
        """Build the ``request_start`` record; client IP and user agent are only ever logged hashed."""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "unknown")[:100]
        return {
            "event": "request_start",
            "timestamp": datetime.now(UTC).isoformat(),
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "client_ip_hash": _hash_sensitive(client_ip),
            "user_agent_hash": _hash_sensitive(user_agent),
        }

    @staticmethod
    async def _summarize_body(request: Request, audit_entry: dict[str, Any]) -> None:
        """Add top-level field names and the biomarker count to *audit_entry*; never values. Best effort."""
        try:
            # Request.body() caches the bytes on the request object itself.
            # NOTE(review): handlers re-reading the body relies on
            # BaseHTTPMiddleware replaying it — confirm for the pinned Starlette.
            body = await request.body()
            if not body:
                return
            body_dict = json.loads(body)
            audit_entry["request_fields"] = list(_redact_body(body_dict).keys())
            # Record how many biomarkers were sent, without their values.
            if "biomarkers" in body_dict:
                biomarkers = body_dict["biomarkers"]
                audit_entry["biomarker_count"] = (
                    len(biomarkers) if isinstance(biomarkers, dict) else 1
                )
        except Exception as exc:
            # Auditing must not block the request: malformed or non-object
            # bodies are skipped with a debug note.
            logger.debug("Failed to audit POST body: %s", exc)

    @staticmethod
    def _log_completion(
        request_id: str,
        method: str,
        path: str,
        status_code: int,
        elapsed_ms: float,
        error: str | None = None,
    ) -> None:
        """Emit the ``request_complete`` audit record; ``error`` is included only when given."""
        entry: dict[str, Any] = {
            "event": "request_complete",
            "timestamp": datetime.now(UTC).isoformat(),
            "request_id": request_id,
            "method": method,
            "path": path,
            "status_code": status_code,
            "elapsed_ms": round(elapsed_ms, 2),
        }
        if error is not None:
            entry["error"] = error
        logger.info("AUDIT_COMPLETE: %s", json.dumps(entry))
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Add security headers for HIPAA compliance.

    Every response is sent with ``nosniff``, ``DENY`` framing, HSTS and
    no-cache headers. Responses for paths starting with an entry of
    ``AUDITABLE_ENDPOINTS`` get ``Cache-Control: no-store, private``. All
    other responses get ``no-store, no-cache, must-revalidate``. Values set
    here overwrite any the route handler already set.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the downstream app, then stamp security headers on its response."""
        response: Response = await call_next(request)
        headers = response.headers

        headers["X-Content-Type-Options"] = "nosniff"
        headers["X-Frame-Options"] = "DENY"
        # "1; mode=block" drives the legacy XSS auditor. Modern browsers have
        # removed it, and in older ones it could itself introduce
        # vulnerabilities. OWASP and MDN recommend disabling it explicitly
        # and relying on Content-Security-Policy instead.
        headers["X-XSS-Protection"] = "0"
        headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        headers["Pragma"] = "no-cache"

        # Medical data should never be cached. Matching is by path prefix,
        # the same rule HIPAAAuditMiddleware uses, so that e.g. "/research"
        # is not mistaken for "/search".
        path = request.url.path
        if any(path.startswith(ep) for ep in AUDITABLE_ENDPOINTS):
            headers["Cache-Control"] = "no-store, private"
        else:
            headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
        return response