# LeomordKaly's picture
# deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)
# aba4f71 verified
"""Structured audit logging for compliance and security tracking.
Records all sensitive operations (queries, data access, ingestion events,
security blocks) with structured metadata. Persists entries to daily JSONL
files for later review, export to SIEM systems, or compliance reporting.
Each entry carries a SHA-256 ``entry_hash`` and the ``prev_hash`` of the
previous entry, forming a hash chain. Any in-place edit, deletion, or
re-ordering of past entries breaks the chain and is detected by
``verify_chain`` / ``scripts/verify_audit_chain.py``.
"""
from __future__ import annotations
import hashlib
import hmac
import json
import threading
from datetime import UTC, date, datetime
from pathlib import Path
from typing import Any
from pydantic import BaseModel, Field
from config.settings import settings
from utils.logging import get_logger
from utils.pii import redact_dict
# Named logger used by this module for operational events (successful writes,
# bootstrap/persist/read failures). It is separate from the JSONL audit files.
_audit_log = get_logger("audit")
# Sentinel used as ``prev_hash`` when no earlier entry exists, i.e. the start
# of a hash chain. Also returned as ``last_hash`` when nothing was verified.
GENESIS_HASH = "GENESIS"
class AuditEntry(BaseModel):
    """One structured, hash-chained audit record.

    Attributes:
        timestamp: UTC time at which the entry object was created.
        action: Event category, e.g. "query", "upload", "access",
            "feedback", "security_block".
        user_id: Identifier of the user who triggered the event.
        org_id: Organization identifier for multi-tenant tracking.
        details: Action-specific payload (query text, file path, etc.).
        sensitivity_level: Data sensitivity classification (low, medium, high).
        status: Outcome, e.g. "success", "blocked", or "error".
        latency_ms: Operation latency in milliseconds, when applicable.
        metadata: Additional free-form metadata.
        prev_hash: The ``entry_hash`` of the preceding entry in the chain
            (``GENESIS`` for the first entry). Empty until persisted.
        entry_hash: Digest returned by :meth:`compute_hash` for this entry.
            Empty until persisted.
    """

    timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC))
    action: str
    user_id: str
    org_id: str = ""
    details: dict[str, Any] = Field(default_factory=dict)
    sensitivity_level: str = "low"
    status: str = "success"
    latency_ms: float | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    prev_hash: str = ""
    entry_hash: str = ""

    def compute_hash(self) -> str:
        """Return the hex digest of this entry's canonical JSON form.

        The canonical form is the JSON-mode dump of every field except
        ``entry_hash``, serialised with sorted keys and compact separators
        so the same entry always yields the same bytes.

        Without ``settings.audit_hmac_key`` the digest is a plain SHA-256
        (tamper-evident). With the key set it is HMAC-SHA256 keyed by that
        secret (tamper-resistant). Verification calls this same method, so
        both sides always agree on the scheme in use.

        Returns:
            Lower-case hexadecimal digest string.
        """
        fields_to_hash = self.model_dump(mode="json", exclude={"entry_hash"})
        canonical_bytes = json.dumps(
            fields_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        ).encode("utf-8")

        secret = settings.audit_hmac_key
        if not secret:
            # No key configured: unkeyed digest.
            return hashlib.sha256(canonical_bytes).hexdigest()

        mac = hmac.new(secret.encode("utf-8"), canonical_bytes, hashlib.sha256)
        return mac.hexdigest()
class AuditLogger:
    """Structured audit trail for security-sensitive operations.

    Persists entries to daily JSONL files (``audit_YYYY-MM-DD.jsonl``) and
    emits structured log events that can be collected by log aggregators
    for compliance reporting. Every persisted entry is linked to the one
    before it through ``prev_hash`` / ``entry_hash``.

    Args:
        log_dir: Directory path for audit log files. Created if not exists.
    """

    def __init__(self, log_dir: str | None = None) -> None:
        """Initialize the audit logger with a target log directory.

        Args:
            log_dir: Directory path for storing daily JSONL audit files.
                Defaults to ``settings.audit_log_dir`` so deployments can pin
                an absolute path via the ``SAR_AUDIT_LOG_DIR`` env var.
        """
        self._log_dir = Path(log_dir if log_dir is not None else settings.audit_log_dir)
        self._log_dir.mkdir(parents=True, exist_ok=True)
        # In-process mutex for hash-chain integrity. Multi-process deployments
        # should serialise audit writes (single writer) or migrate to Postgres.
        self._chain_lock = threading.Lock()
        self._last_hash: str | None = None  # Lazily resolved on first write

    @staticmethod
    def _file_date(path: Path) -> date | None:
        """Return the date encoded in an ``audit_YYYY-MM-DD.jsonl`` file name.

        Args:
            path: Path of a file matched by the ``audit_*.jsonl`` glob.

        Returns:
            The parsed date, or ``None`` when the part after ``audit_`` is
            not an ISO date (e.g. ``audit_backup.jsonl``).
        """
        try:
            return date.fromisoformat(path.stem.split("_", 1)[1])
        except (IndexError, ValueError):
            return None

    def _read_last_hash(self) -> str:
        """Scan the audit directory for the most recent entry's hash.

        Used once on first write to bootstrap the chain. Files are visited
        newest-first (by sorted name); within a file only the last non-blank
        line is inspected. A file whose last entry carries no ``entry_hash``
        is passed over in favour of the next older file.

        Returns:
            The most recent ``entry_hash`` found, or ``GENESIS`` when no
            prior entry exists or the scan fails for any reason.
        """
        try:
            files = sorted(self._log_dir.glob("audit_*.jsonl"))
            for path in reversed(files):
                # Stream the file forward and remember only the final
                # non-blank line; the newest entry is the last one written.
                last_line = ""
                with open(path, encoding="utf-8") as f:
                    for ln in f:
                        if ln.strip():
                            last_line = ln
                if not last_line:
                    continue
                last_hash = json.loads(last_line).get("entry_hash")
                if last_hash:
                    return last_hash
            return GENESIS_HASH
        except Exception as exc:
            # Best-effort bootstrap: a failed scan restarts the chain rather
            # than blocking audit writes.
            _audit_log.warning("audit_chain_bootstrap_failed", error=str(exc))
            return GENESIS_HASH

    def log_query(
        self,
        *,
        user_id: str,
        org_id: str = "",
        query: str,
        response_summary: str = "",
        sensitivity: str = "low",
        status: str = "success",
        latency_ms: float | None = None,
        **kwargs: Any,
    ) -> AuditEntry:
        """Record a user query event.

        Args:
            user_id: Identifier of the user making the query.
            org_id: Organization identifier.
            query: The natural-language query text.
            response_summary: Brief summary of the generated response.
            sensitivity: Data sensitivity level.
            status: Query outcome status.
            latency_ms: Query processing time in milliseconds.
            **kwargs: Additional metadata fields.

        Returns:
            The AuditEntry handed to ``_persist`` (returned even when the
            write itself failed, since persistence errors are only logged).
        """
        entry = AuditEntry(
            action="query",
            user_id=user_id,
            org_id=org_id,
            details={"query": query, "response_summary": response_summary},
            sensitivity_level=sensitivity,
            status=status,
            latency_ms=latency_ms,
            metadata=kwargs,
        )
        self._persist(entry)
        # Only the query length goes to the operational log, not its text.
        _audit_log.info(
            "query_executed",
            user_id=user_id,
            query_length=len(query),
            status=status,
            latency_ms=latency_ms,
        )
        return entry

    def log_access(
        self,
        *,
        user_id: str,
        org_id: str = "",
        documents_accessed: list[str] | None = None,
        sensitivity: str = "low",
        **kwargs: Any,
    ) -> AuditEntry:
        """Record a document access event.

        Args:
            user_id: Identifier of the requesting user.
            org_id: Organization identifier.
            documents_accessed: List of document IDs or names accessed.
                ``None`` is recorded as an empty list.
            sensitivity: Data sensitivity level.
            **kwargs: Additional metadata fields.

        Returns:
            The AuditEntry handed to ``_persist``.
        """
        entry = AuditEntry(
            action="access",
            user_id=user_id,
            org_id=org_id,
            details={"documents_accessed": documents_accessed or []},
            sensitivity_level=sensitivity,
            status="success",
            metadata=kwargs,
        )
        self._persist(entry)
        _audit_log.info(
            "access_event",
            user_id=user_id,
            doc_count=len(documents_accessed or []),
            sensitivity=sensitivity,
        )
        return entry

    def log_ingestion(
        self,
        *,
        user_id: str,
        org_id: str = "",
        file_path: str = "",
        num_chunks: int = 0,
        status: str = "success",
        # Legacy support for existing pipeline calls
        document_name: str = "",
        chunk_count: int = 0,
        **kwargs: Any,
    ) -> AuditEntry:
        """Record a document ingestion event.

        Args:
            user_id: Identifier of the user who triggered ingestion.
            org_id: Organization identifier.
            file_path: Path of the ingested document.
            num_chunks: Number of chunks produced.
            status: Ingestion outcome status.
            document_name: Legacy parameter, used when ``file_path`` is empty.
            chunk_count: Legacy parameter, used when ``num_chunks`` is 0.
            **kwargs: Additional metadata fields. A ``sensitivity`` key, if
                present, is removed and used as the entry's sensitivity
                level (default ``"low"``).

        Returns:
            The AuditEntry handed to ``_persist``.
        """
        # Support legacy call signature from pipeline.py
        actual_path = file_path or document_name
        actual_chunks = num_chunks or chunk_count
        entry = AuditEntry(
            action="upload",
            user_id=user_id,
            org_id=org_id,
            details={"file_path": actual_path, "num_chunks": actual_chunks},
            sensitivity_level=kwargs.pop("sensitivity", "low"),
            status=status,
            metadata=kwargs,
        )
        self._persist(entry)
        _audit_log.info(
            "document_ingested",
            user_id=user_id,
            file_path=actual_path,
            chunk_count=actual_chunks,
            status=status,
        )
        return entry

    def log_feedback(
        self,
        *,
        user_id: str,
        org_id: str = "",
        rating: str,
        query: str = "",
        **kwargs: Any,
    ) -> AuditEntry:
        """Record a user thumbs-up/down on an answer.

        Feedback lands on the same hash chain as every other event and is
        PII-redacted by ``_persist`` before it is written, so it is itself
        tamper-evident. ``rating`` is stored exactly as passed; no
        normalisation happens here, so callers should supply ``up`` or
        ``down``.

        Args:
            user_id: Identifier of the user giving feedback.
            org_id: Organization identifier.
            rating: ``up`` or ``down``.
            query: The question the rated answer responded to.
            **kwargs: Extra metadata (e.g. answer_summary).

        Returns:
            The AuditEntry handed to ``_persist``.
        """
        entry = AuditEntry(
            action="feedback",
            user_id=user_id,
            org_id=org_id,
            details={"rating": rating, "query": query},
            status="success",
            metadata=kwargs,
        )
        self._persist(entry)
        _audit_log.info("feedback_recorded", user_id=user_id, rating=rating)
        return entry

    def log_security_event(
        self,
        *,
        user_id: str,
        org_id: str = "",
        event_type: str,
        details: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AuditEntry:
        """Record a security-relevant event (block, escalation, anomaly).

        The entry is always written with sensitivity ``"high"`` and status
        ``"blocked"``.

        Args:
            user_id: Identifier of the user involved.
            org_id: Organization identifier.
            event_type: Type of security event (e.g., "rbac_block", "injection_attempt").
            details: Event-specific details. Merged after ``event_type``, so
                an ``event_type`` key inside ``details`` overrides the argument.
            **kwargs: Additional metadata fields.

        Returns:
            The AuditEntry handed to ``_persist``.
        """
        entry = AuditEntry(
            action="security_block",
            user_id=user_id,
            org_id=org_id,
            details={"event_type": event_type, **(details or {})},
            sensitivity_level="high",
            status="blocked",
            metadata=kwargs,
        )
        self._persist(entry)
        _audit_log.warning(
            "security_event",
            user_id=user_id,
            event_type=event_type,
            org_id=org_id,
        )
        return entry

    def log_access_legacy(
        self,
        *,
        user_id: str,
        resource: str,
        action: str,
        granted: bool,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Record a resource access event (legacy interface).

        Maintains backward compatibility with the original AuditLogger API.
        The entry is written with action ``"access"``, sensitivity
        ``"medium"`` and status ``"success"`` or ``"blocked"`` depending on
        ``granted``.

        Args:
            user_id: Identifier of the requesting user.
            resource: The resource being accessed.
            action: Action attempted (read, write, delete).
            granted: Whether access was granted.
            metadata: Additional context.
        """
        entry = AuditEntry(
            action="access",
            user_id=user_id,
            org_id="",
            details={"resource": resource, "action": action, "granted": granted},
            sensitivity_level="medium",
            status="success" if granted else "blocked",
            metadata=metadata or {},
        )
        self._persist(entry)
        _audit_log.info(
            "access_event",
            user_id=user_id,
            resource=resource,
            action=action,
            granted=granted,
        )

    def _persist(self, entry: AuditEntry) -> None:
        """Append an audit entry to the daily JSONL file with hash chaining.

        File naming convention: ``audit_YYYY-MM-DD.jsonl``. Each entry's
        ``prev_hash`` references the previous entry's ``entry_hash`` so the
        whole stream is a tamper-evident chain. Writes are serialised by
        ``self._chain_lock`` to prevent concurrent re-use of a ``prev_hash``.

        Failures are logged as ``audit_persist_failed`` and swallowed, so a
        broken audit sink never raises into the caller. ``self._last_hash``
        only advances after a successful write.

        Args:
            entry: The AuditEntry to persist. Mutated in place: ``details``
                and ``metadata`` are redacted and both hash fields are set.
        """
        try:
            with self._chain_lock:
                if self._last_hash is None:
                    self._last_hash = self._read_last_hash()
                # Scrub PII *before* hashing so the on-disk entry and its
                # signature match.
                entry.details = redact_dict(entry.details)
                entry.metadata = redact_dict(entry.metadata)
                entry.prev_hash = self._last_hash
                entry.entry_hash = entry.compute_hash()
                # NOTE(review): the file is chosen by the *local* calendar
                # date while ``entry.timestamp`` is UTC, so an entry near
                # midnight can sit in a file whose name differs from its
                # timestamp's date. Confirm this is intended.
                today = date.today().isoformat()
                file_path = self._log_dir / f"audit_{today}.jsonl"
                line = entry.model_dump_json() + "\n"
                with open(file_path, "a", encoding="utf-8") as f:
                    f.write(line)
                self._last_hash = entry.entry_hash
        except Exception as exc:
            _audit_log.error("audit_persist_failed", error=str(exc))

    def verify_chain(
        self,
        start_date: date | str | None = None,
        end_date: date | str | None = None,
    ) -> dict[str, Any]:
        """Verify the hash chain across one or more daily audit files.

        Walks every entry in ``[start_date, end_date]`` (or all files when
        both are ``None``) and recomputes ``entry_hash`` while checking that
        each entry's ``prev_hash`` matches the previous entry's hash.

        When ``start_date`` is given the first visited entry's ``prev_hash``
        is not checked, because its predecessor lies outside the range.
        Verification stops at the first line that cannot be parsed; hash and
        chain mismatches are collected and the walk continues.

        Files matching ``audit_*.jsonl`` whose name carries no ISO date are
        left out of a date-bounded run instead of raising. They are still
        visited when no bound is given.

        Args:
            start_date: First date (inclusive). ``None`` for "earliest".
            end_date: Last date (inclusive). ``None`` for "latest".

        Returns:
            Dict with keys ``valid`` (bool), ``checked`` (int), ``broken_at``
            (list of file:line:reason strings), and ``last_hash`` (str).

        Raises:
            ValueError: If a string bound is not a valid ISO date.
        """
        files = sorted(self._log_dir.glob("audit_*.jsonl"))
        if isinstance(start_date, str):
            start_date = date.fromisoformat(start_date)
        if isinstance(end_date, str):
            end_date = date.fromisoformat(end_date)

        if start_date is not None or end_date is not None:
            in_range: list[Path] = []
            for path in files:
                file_day = self._file_date(path)
                if file_day is None:
                    # An undated file cannot be placed inside a date range.
                    continue
                if start_date is not None and file_day < start_date:
                    continue
                if end_date is not None and file_day > end_date:
                    continue
                in_range.append(path)
            files = in_range

        broken: list[str] = []
        checked = 0
        # With no lower bound the very first entry must chain from GENESIS.
        expected_prev = GENESIS_HASH if start_date is None else None
        last_hash = GENESIS_HASH
        for path in files:
            with open(path, encoding="utf-8") as f:
                for lineno, raw in enumerate(f, start=1):
                    raw = raw.strip()
                    if not raw:
                        continue
                    try:
                        entry = AuditEntry.model_validate_json(raw)
                    except Exception as exc:
                        # Nothing after an unreadable line can be trusted.
                        broken.append(f"{path.name}:{lineno}:parse_error:{exc}")
                        return {
                            "valid": False,
                            "checked": checked,
                            "broken_at": broken,
                            "last_hash": last_hash,
                        }
                    recomputed = entry.compute_hash()
                    if recomputed != entry.entry_hash:
                        broken.append(
                            f"{path.name}:{lineno}:hash_mismatch:"
                            f"stored={entry.entry_hash[:12]} recomputed={recomputed[:12]}"
                        )
                    if expected_prev is not None and entry.prev_hash != expected_prev:
                        broken.append(
                            f"{path.name}:{lineno}:chain_broken:"
                            f"prev_hash={entry.prev_hash[:12]} expected={expected_prev[:12]}"
                        )
                    expected_prev = entry.entry_hash
                    last_hash = entry.entry_hash
                    checked += 1
        return {
            "valid": not broken,
            "checked": checked,
            "broken_at": broken,
            "last_hash": last_hash,
        }

    def get_entries(
        self,
        start_date: date | str,
        end_date: date | str,
        user_id: str | None = None,
        action: str | None = None,
    ) -> list[AuditEntry]:
        """Read and filter audit entries from persisted JSONL files.

        Days without a file are skipped. A line that fails validation is
        skipped and counted, and the rest of that file is still read. Each
        affected file produces one ``audit_read_failed`` log event. Any
        other read error abandons that file only. This method does not
        verify the hash chain.

        Args:
            start_date: Start date (inclusive) for the query range.
            end_date: End date (inclusive) for the query range.
            user_id: Optional filter by user identifier.
            action: Optional filter by action type.

        Returns:
            List of matching AuditEntry objects in file order, oldest day
            first. Empty when ``start_date`` is after ``end_date``.

        Raises:
            ValueError: If a string bound is not a valid ISO date.
        """
        # Imported once per call (not per loop iteration); ``timedelta`` is
        # not among the names this module imports at file level.
        from datetime import timedelta

        if isinstance(start_date, str):
            start_date = date.fromisoformat(start_date)
        if isinstance(end_date, str):
            end_date = date.fromisoformat(end_date)

        one_day = timedelta(days=1)
        entries: list[AuditEntry] = []
        current = start_date
        while current <= end_date:
            file_path = self._log_dir / f"audit_{current.isoformat()}.jsonl"
            if file_path.exists():
                skipped = 0
                first_error = ""
                try:
                    with open(file_path, encoding="utf-8") as f:
                        for line in f:
                            line = line.strip()
                            if not line:
                                continue
                            try:
                                entry = AuditEntry.model_validate_json(line)
                            except ValueError as exc:
                                # pydantic's ValidationError is a ValueError.
                                # One bad line must not hide the entries
                                # that follow it.
                                skipped += 1
                                first_error = first_error or str(exc)
                                continue
                            # Apply filters
                            if user_id and entry.user_id != user_id:
                                continue
                            if action and entry.action != action:
                                continue
                            entries.append(entry)
                except Exception as exc:
                    _audit_log.error(
                        "audit_read_failed",
                        file=str(file_path),
                        error=str(exc),
                    )
                if skipped:
                    _audit_log.error(
                        "audit_read_failed",
                        file=str(file_path),
                        error=first_error,
                        skipped_lines=skipped,
                    )
            # Advance to next day
            current = current + one_day
        return entries

    def get_summary(
        self,
        start_date: date | str,
        end_date: date | str,
    ) -> dict[str, Any]:
        """Generate aggregate summary of audit entries over a date range.

        Args:
            start_date: Start date (inclusive).
            end_date: End date (inclusive).

        Returns:
            Dictionary with ``total_entries``, per-key counts in
            ``by_action`` / ``by_user`` / ``by_status``, and ``date_range``
            holding ``str()`` of the two bounds as they were passed in.
        """
        entries = self.get_entries(start_date, end_date)
        by_action: dict[str, int] = {}
        by_user: dict[str, int] = {}
        by_status: dict[str, int] = {}
        for entry in entries:
            by_action[entry.action] = by_action.get(entry.action, 0) + 1
            by_user[entry.user_id] = by_user.get(entry.user_id, 0) + 1
            by_status[entry.status] = by_status.get(entry.status, 0) + 1
        return {
            "total_entries": len(entries),
            "by_action": by_action,
            "by_user": by_user,
            "by_status": by_status,
            "date_range": {
                "start": str(start_date),
                "end": str(end_date),
            },
        }
# Module-level singleton shared by importers. It is constructed at import
# time with no ``log_dir`` argument, so ``AuditLogger.__init__`` resolves the
# directory from ``settings.audit_log_dir`` and creates it if it is missing.
audit_logger = AuditLogger()