digital-twin / security.py
LSmithPMP's picture
Deploy: security-hardened RAG digital twin
6ca927d
Raw
History Blame Contribute Delete
8.13 kB
"""
Security module for the Digital Twin application.
Implements defense-in-depth controls: input validation, output filtering,
rate limiting, and content safety checks.
Security design principle: All controls are enforced at the application layer,
independent of LLM behavior or platform-level protections.
"""
import fnmatch
import logging
import re
import threading
import time
import unicodedata
from collections import defaultdict

import config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Input Validation
# ---------------------------------------------------------------------------
# Maximum characters allowed per user message
MAX_INPUT_LENGTH = 2000
# Maximum conversation turns before session reset is recommended
MAX_CONVERSATION_TURNS = 50
# Patterns commonly used in prompt injection attempts
_INJECTION_PATTERNS = [
    r"ignore\s+(all\s+)?(previous|prior|above)\s+(instructions|prompts|rules)",
    r"you\s+are\s+now\s+(a|an|the)\s+",
    r"system\s*prompt",
    r"reveal\s+(your|the)\s+(instructions|prompt|rules|system)",
    r"act\s+as\s+(a|an|if)\s+",
    r"pretend\s+(you\s+are|to\s+be)",
    r"forget\s+(everything|all|your\s+instructions)",
    r"override\s+(your|the|all)\s+",
    r"<\s*/?\s*system\s*>",
    r"<\s*/?\s*developer\s*>",
    r"\[\s*INST\s*\]",
    r"\[\s*/\s*INST\s*\]",
]
_COMPILED_PATTERNS = [re.compile(p, re.IGNORECASE) for p in _INJECTION_PATTERNS]
# Control characters that carry line structure and are kept in the matching view.
_KEPT_CONTROL_CHARS = frozenset("\t\n\r")


def _matching_view(text: str) -> str:
    """
    Return a normalized copy of *text* used only for injection-pattern matching.

    The copy is never sent to the model. It exists so that cosmetic
    obfuscation cannot split a trigger phrase past the regexes:

    - NFKC normalization folds compatibility forms (e.g. full-width letters)
      to their plain equivalents.
    - Control characters (Unicode category ``Cc``, except tab/LF/CR) are
      removed. sanitize_input() strips C0 control characters after
      validation, so a NUL byte wedged into a trigger phrase would otherwise
      slip past the raw-text check and then be removed before the text
      reaches the model.
    - Invisible format characters (category ``Cf``, e.g. zero-width spaces)
      are removed.
    """
    folded = unicodedata.normalize("NFKC", text)
    return "".join(
        ch for ch in folded
        if ch in _KEPT_CONTROL_CHARS or unicodedata.category(ch) not in ("Cc", "Cf")
    )


def validate_input(user_input: str) -> tuple[bool, str]:
    """
    Validate user input before processing.

    Checks, in order:

    1. Empty or whitespace-only input.
    2. Length above MAX_INPUT_LENGTH characters.
    3. Input that is empty once control and invisible format characters are
       removed.
    4. Prompt-injection patterns. These are matched against both the raw
       text and its normalized view; a match in either form rejects.

    Returns:
        (is_valid, message): ``(True, "valid")`` on success. Otherwise
        ``False`` with a reason code: ``"empty_input"``,
        ``"input_too_long"`` or ``"injection_detected"``.
    """
    # Empty / whitespace-only check
    if not user_input or not user_input.strip():
        return False, "empty_input"
    # Length check runs before normalization so oversized input is rejected cheaply.
    if len(user_input) > MAX_INPUT_LENGTH:
        return False, "input_too_long"
    view = _matching_view(user_input)
    # Text made only of control/format characters has no visible content
    # (its C0 controls would be stripped by sanitize_input()); treat it as empty.
    if not view.strip():
        return False, "empty_input"
    # Prompt injection pattern detection on the raw text and the normalized view.
    for pattern in _COMPILED_PATTERNS:
        if pattern.search(user_input) or pattern.search(view):
            logger.warning("Prompt injection pattern detected in user input")
            return False, "injection_detected"
    return True, "valid"
def sanitize_input(user_input: str) -> str:
    """
    Sanitize user input by stripping control characters and normalizing whitespace.
    Applied after validation passes.

    Steps:

    - Remove C0 control characters other than tab, LF and CR, plus DEL and
      the C1 control range (U+0080-U+009F).
    - Convert CRLF and lone-CR line endings to LF, so the blank-line collapse
      below works regardless of the client's line-ending style.
    - Collapse runs of three or more newlines into a single blank line.
    - Strip leading and trailing whitespace.
    """
    # Remove null bytes and control characters (keep tab, LF, CR)
    sanitized = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]', '', user_input)
    # Normalize line endings before collapsing blank lines
    sanitized = sanitized.replace('\r\n', '\n').replace('\r', '\n')
    # Normalize excessive vertical whitespace
    sanitized = re.sub(r'\n{3,}', '\n\n', sanitized)
    return sanitized.strip()
# ---------------------------------------------------------------------------
# Output Filtering
# ---------------------------------------------------------------------------
# Phrases that should never appear in model output (architecture disclosure)
# NOTE: filter_output() tests each entry as a plain substring of the
# lower-cased model output. Entries must therefore be lowercase. Short entries
# (e.g. ".env", "rerank") also match inside longer words ("reranking").
# The first matching entry, in list order, is the one that gets logged.
_OUTPUT_BLOCKLIST = [
    # Prompt / role terminology
    "system prompt",
    "system message",
    "developer role",
    # Retrieval-context and vector-store identifiers
    "retrieved_context",
    "chromadb",
    "chroma_path",
    # Internal file / module names
    "biography.txt",
    "build_vectors",
    "rag.py",
    "inference.py",
    "prompts.py",
    "config.py",
    "tools.py",
    "app.py",
    "security.py",
    # Secret / credential names
    "openai_api_key",
    "pushover_user",
    "pushover_token",
    "hf_token",
    ".env",
    # UI state and internal variable names
    "gr.state",
    "api_messages",
    "tool_registry",
    # Retrieval parameters and pipeline terminology
    "distance_threshold",
    "n_results",
    "bm25_index",
    "bm25index",
    "bm25 index",
    "section_filter",
    "neighbor_window",
    "neighbor expansion",
    "global_idx",
    "context_injection",
    # Context-markup fragments
    "chunk source=",
    "section name=",
    # Context-management internals
    "_prune_stale",
    "max_context_chunks",
    "max_retained_injections",
    "query routing",
    "hybrid search",
    "rerank",
]
def filter_output(text: str) -> str:
    """
    Scan model output for inadvertent disclosure of system architecture,
    internal file names, or secret references. Replace flagged content
    with a safe in-character response.

    Matching is a caseless substring test against _OUTPUT_BLOCKLIST. It runs
    on a copy of the text in which every whitespace run is collapsed to a
    single space. A line break or repeated spaces inside a phrase therefore
    cannot evade the single-spaced blocklist entries.

    Returns:
        *text* unchanged if nothing matched, or if *text* is empty/None.
        Otherwise a fixed deflection message.
    """
    if not text:
        # Nothing to scan (empty or missing completion); pass it through as-is.
        return text
    normalized = re.sub(r"\s+", " ", text.casefold())
    for phrase in _OUTPUT_BLOCKLIST:
        if phrase in normalized:
            logger.warning("Output filter triggered on phrase: %s", phrase)
            return ("I appreciate the question, but I'm not able to share details about "
                    "my internal architecture. Feel free to ask me about my research, "
                    "career, or professional interests instead.")
    return text
# ---------------------------------------------------------------------------
# Rate Limiting (Per-Session)
# ---------------------------------------------------------------------------
class SessionRateLimiter:
    """
    Sliding-window rate limiter scoped to individual sessions.
    Prevents query flooding and tool abuse at the application layer.

    Each session keeps the timestamps of its *accepted* requests. A new
    request is accepted only if fewer than the configured maximum were
    accepted within the trailing window: 60 s for queries, 3600 s for
    notifications. Rejected requests are not recorded.

    Timestamps come from time.monotonic(), so wall-clock adjustments cannot
    shrink or stretch a window. Sessions with no activity inside their window
    are purged periodically. Without that purge, the per-session logs would
    grow without bound as new session ids appear.
    """

    _QUERY_WINDOW_SECONDS = 60
    _NOTIFICATION_WINDOW_SECONDS = 3600
    # Minimum interval between purges of idle sessions.
    _SWEEP_INTERVAL_SECONDS = 300

    def __init__(self, max_queries_per_minute: int = 10, max_notifications_per_hour: int = 5):
        self._query_timestamps: defaultdict[str, list[float]] = defaultdict(list)
        self._notification_timestamps: defaultdict[str, list[float]] = defaultdict(list)
        self._max_qpm = max_queries_per_minute
        self._max_nph = max_notifications_per_hour
        # One instance is shared module-wide (`rate_limiter` below) and may be
        # called concurrently. The lock makes prune-check-record atomic, so a
        # burst of simultaneous requests cannot all pass the limit.
        self._lock = threading.Lock()
        self._last_sweep = time.monotonic()

    def check_query_rate(self, session_id: str = "default") -> bool:
        """Return True if query is within rate limit (and record it as accepted)."""
        return self._check(
            self._query_timestamps,
            session_id,
            window_seconds=self._QUERY_WINDOW_SECONDS,
            limit=self._max_qpm,
            exceeded_msg="Query rate limit exceeded for session %s",
        )

    def check_notification_rate(self, session_id: str = "default") -> bool:
        """Return True if notification is within rate limit (and record it as accepted)."""
        return self._check(
            self._notification_timestamps,
            session_id,
            window_seconds=self._NOTIFICATION_WINDOW_SECONDS,
            limit=self._max_nph,
            exceeded_msg="Notification rate limit exceeded for session %s",
        )

    def _check(
        self,
        store: dict[str, list[float]],
        session_id: str,
        *,
        window_seconds: float,
        limit: int,
        exceeded_msg: str,
    ) -> bool:
        """Prune, test and (if allowed) record one request for *session_id* in *store*."""
        now = time.monotonic()
        with self._lock:
            self._sweep_idle_sessions(now)
            recent = [t for t in store.get(session_id, ()) if now - t < window_seconds]
            allowed = len(recent) < limit
            if allowed:
                recent.append(now)
            store[session_id] = recent
        if not allowed:
            logger.warning(exceeded_msg, session_id[:8])
        return allowed

    def _sweep_idle_sessions(self, now: float) -> None:
        """Drop sessions whose newest timestamp has aged out of its window.

        Runs at most once per _SWEEP_INTERVAL_SECONDS. The caller must hold self._lock.
        """
        if now - self._last_sweep < self._SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        for store, window_seconds in (
            (self._query_timestamps, self._QUERY_WINDOW_SECONDS),
            (self._notification_timestamps, self._NOTIFICATION_WINDOW_SECONDS),
        ):
            # Timestamps are appended in monotonic order, so stamps[-1] is the newest.
            idle = [
                sid for sid, stamps in store.items()
                if not stamps or now - stamps[-1] >= window_seconds
            ]
            for sid in idle:
                del store[sid]


# Singleton rate limiter instance
rate_limiter = SessionRateLimiter()
# ---------------------------------------------------------------------------
# Conversation Depth Guard
# ---------------------------------------------------------------------------
def check_conversation_depth(api_messages: list) -> bool:
    """
    Report whether a conversation is still below the maximum safe depth.

    Depth is the number of entries in *api_messages* that are dicts with
    ``'role'`` equal to ``'user'``; any other entries are ignored. Deeper
    conversations enlarge the context window and with it the extraction risk.

    Returns:
        True while the user-turn count is below MAX_CONVERSATION_TURNS.
        False (after an info log) once the count reaches that limit.
    """
    user_turns = 0
    for message in api_messages:
        if isinstance(message, dict) and message.get('role') == 'user':
            user_turns += 1

    within_limit = user_turns < MAX_CONVERSATION_TURNS
    if not within_limit:
        logger.info("Conversation depth limit reached (%d turns)", user_turns)
    return within_limit
# ---------------------------------------------------------------------------
# Startup Security Audit
# ---------------------------------------------------------------------------
def _gitignore_ignores_env(gitignore_text: str) -> bool:
    """
    Return True if *gitignore_text* has a rule that ignores a top-level ``.env`` file.

    This approximates gitignore semantics for that single path:

    - Blank lines and ``#`` comments are skipped.
    - A leading ``/`` or ``**/`` anchor is dropped.
    - Directory-only patterns (trailing ``/``) never match the file.
    - Patterns are matched case-sensitively with fnmatch.
    - The last matching rule wins, so a later ``!.env`` re-includes the file.
    """
    ignored = False
    for raw_line in gitignore_text.splitlines():
        rule = raw_line.strip()
        if not rule or rule.startswith('#'):
            continue
        negate = rule.startswith('!')
        if negate:
            rule = rule[1:]
        if rule.endswith('/'):
            continue  # matches directories only, never the .env file
        for anchor in ('**/', '/'):
            if rule.startswith(anchor):
                rule = rule[len(anchor):]
                break
        if fnmatch.fnmatchcase('.env', rule):
            ignored = not negate
    return ignored


def audit_startup_security() -> list[str]:
    """
    Run security checks at application startup.
    Returns a list of warnings (empty list = all clear).

    Checks performed:

    - OPENAI_API_KEY is set in the environment.
    - If ``config.BASE_DIR/.env`` exists, a ``.gitignore`` next to it exists,
      is readable, and has a rule that actually ignores ``.env``. A line such
      as ``.env.example`` or ``# .env`` does not count.
    - ``config.CHROMA_CLIENT_SETTINGS.anonymized_telemetry`` is exactly False.

    Every finding is also logged at WARNING level.
    """
    import os
    findings = []
    # Check that required secrets are present
    if not os.environ.get("OPENAI_API_KEY"):
        findings.append("CRITICAL: OPENAI_API_KEY not set in environment")
    # Check that .env is not accidentally committed
    env_path = config.BASE_DIR / '.env'
    gitignore_path = config.BASE_DIR / '.gitignore'
    if env_path.exists():
        if not gitignore_path.exists():
            # No .gitignore at all is the riskiest case: nothing keeps .env out of git.
            findings.append("WARNING: .env exists but no .gitignore was found")
        else:
            try:
                gitignore_content = gitignore_path.read_text(encoding='utf-8', errors='replace')
            except OSError as exc:
                # Don't let an unreadable file abort startup; report it instead.
                findings.append(f"WARNING: could not read .gitignore ({type(exc).__name__})")
            else:
                if not _gitignore_ignores_env(gitignore_content):
                    findings.append("WARNING: .env exists but is not in .gitignore")
    # Check ChromaDB telemetry is disabled
    if config.CHROMA_CLIENT_SETTINGS.anonymized_telemetry is not False:
        findings.append("WARNING: ChromaDB telemetry is not disabled")
    for finding in findings:
        logger.warning("Startup security audit: %s", finding)
    return findings