|
|
""" |
|
|
Base classes for the Secure AI Agents Suite.
|
|
Provides common functionality for all agents including security, logging, and MCP integration. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import hashlib |
|
|
import logging |
|
|
import re |
|
|
from datetime import datetime |
|
|
from typing import Dict, List, Optional, Any |
|
|
from abc import ABC, abstractmethod |
|
|
import aiohttp |
|
|
import asyncio |
|
|
|
|
|
|
|
|
class SecurityMiddleware:
    """Handles security features: prompt injection defense, output sanitization, audit logging.

    All checks are regex heuristics: they catch common literal attack phrasings
    and obvious PII formats, not every possible variant.
    """

    # (compiled pattern, replacement) pairs applied in sequence by
    # ``sanitize_output``. Compiled once at class creation rather than per call.
    _SANITIZE_PATTERNS = (
        # Only unseparated 16-digit runs are masked; spaced/dashed card numbers are not.
        (re.compile(r'\b\d{16}\b'), '[CREDIT_CARD_MASKED]'),
        (re.compile(r'\b\d{3}-\d{2}-\d{4}\b'), '[SSN_MASKED]'),
        # TLD must be letters only ([A-Za-z]); a '|' inside the class would be literal.
        (re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'), '[EMAIL_MASKED]'),
        (re.compile(r'\b(?:\+?1[-.\s]?)?\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}\b'), '[PHONE_MASKED]'),
    )

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and set up the audit logger and injection patterns.

        Args:
            config: Agent configuration. Stored on the instance; not read by this class.
        """
        self.config = config
        self.audit_logger = logging.getLogger(f"{__name__}.audit")

        # Regexes matched against the *lower-cased* prompt, so they must be
        # written in lower case. Kept as plain strings so callers can extend
        # or replace the list at runtime.
        self.injection_patterns = [
            r"ignore\s+previous\s+instructions",
            r"forget\s+everything\s+above",
            r"system\s*:\s*ignore",
            r"<system>",
            # '?' must be escaped: unescaped, it makes '<' optional and the
            # pattern then flags any prompt that merely mentions "xml".
            r"<\?xml",
            r"<script",
            r"javascript:",
            r"eval\(",
            r"exec\(",
            r"__import__",
            r"subprocess",
            r"os\.system",
            r"shell\s*=\s*True"
        ]

    def detect_prompt_injection(self, prompt: str) -> bool:
        """Return True if *prompt* matches any entry in ``injection_patterns``.

        The prompt is lower-cased before matching, so detection is
        case-insensitive for the lower-case patterns above.
        """
        prompt_lower = prompt.lower()
        return any(re.search(pattern, prompt_lower) for pattern in self.injection_patterns)

    def sanitize_output(self, content: str) -> str:
        """Mask card numbers, SSNs, email addresses and phone numbers in *content*.

        Returns:
            A copy of *content* with each match replaced by a ``[..._MASKED]`` token.
        """
        sanitized = content
        for pattern, replacement in self._SANITIZE_PATTERNS:
            sanitized = pattern.sub(replacement, sanitized)
        return sanitized

    def log_mcp_call(self, session_id: str, tool: str, request: Dict, response: Dict):
        """Write one JSON audit record for an MCP call to the audit logger.

        Only SHA-256 digests of the request and response, plus the tool name and
        parameter *names*, are logged, so parameter values never reach the log.

        Args:
            session_id: Identifier used to correlate audit entries.
            tool: Name of the tool that was called.
            request: JSON-serializable request payload.
            response: JSON-serializable response payload.
        """
        parameters = request.get("parameters")
        # Only a dict has parameter names to summarise; anything else is logged
        # as "no parameters" rather than crashing the audit path.
        parameter_names = list(parameters.keys()) if isinstance(parameters, dict) else []

        audit_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "session_id": session_id,
            "tool": tool,
            # sort_keys makes the hash independent of dict insertion order.
            "request_hash": hashlib.sha256(json.dumps(request, sort_keys=True).encode()).hexdigest(),
            "response_hash": hashlib.sha256(json.dumps(response, sort_keys=True).encode()).hexdigest(),
            "request_summary": {
                "tool": request.get("tool"),
                "parameters": parameter_names
            }
        }

        self.audit_logger.info(json.dumps(audit_entry))
|
|
|
|
|
|
|
|
class MCPClient:
    """Client for communicating with MCP (Model Context Protocol) servers.

    Intended to be used as an async context manager::

        async with MCPClient(url, config) as client:
            result = await client.call_tool("tool_name", {...})

    Entering opens an ``aiohttp.ClientSession`` and fills ``available_tools``
    from ``POST {url}/tools/list``; exiting closes the session.
    """

    def __init__(self, mcp_server_url: str, config: Dict[str, Any]):
        self.mcp_server_url = mcp_server_url
        # Stored on the instance; not read by this class.
        self.config = config
        # Open only between __aenter__ and __aexit__.
        self.session: Optional[aiohttp.ClientSession] = None
        # Tool name -> tool descriptor, taken from the "tools" list of /tools/list.
        self.available_tools: Dict[str, Any] = {}

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        try:
            await self.discover_tools()
        except BaseException:
            # discover_tools() swallows ordinary exceptions, but cancellation
            # (a BaseException) can still escape. __aexit__ is not invoked when
            # __aenter__ raises, so close the session here to avoid leaking it.
            await self._close_session()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._close_session()

    async def _close_session(self) -> None:
        """Close the HTTP session, if any, and drop the reference so it is not reused."""
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    async def discover_tools(self):
        """Populate ``available_tools`` from the server's ``/tools/list`` endpoint.

        Failures (no open session, non-200 status, transport errors, or a
        malformed payload) are logged and leave ``available_tools`` unchanged.
        """
        if self.session is None:
            logging.error("Error discovering tools: client session is not open")
            return

        try:
            async with self.session.post(f"{self.mcp_server_url}/tools/list") as response:
                if response.status != 200:
                    logging.warning(f"Failed to discover tools: {response.status}")
                    return
                tools_data = await response.json()
                self.available_tools = {tool["name"]: tool for tool in tools_data.get("tools", [])}
        except Exception as e:
            logging.error(f"Error discovering tools: {e}")

    async def call_tool(self, tool_name: str, parameters: Dict[str, Any], session_id: str = "default") -> Dict[str, Any]:
        """Call *tool_name* on the MCP server with *parameters*.

        Args:
            tool_name: Must be a key of ``available_tools``.
            parameters: Sent as the ``"parameters"`` field of the request body.
            session_id: Not used by this method; accepted for caller compatibility.

        Returns:
            The decoded JSON body on HTTP 200. Otherwise a dict with a single
            ``"error"`` message (non-200 status, transport/decoding error, or no
            open session).

        Raises:
            ValueError: If *tool_name* has not been discovered.
        """
        if tool_name not in self.available_tools:
            raise ValueError(f"Tool {tool_name} not available. Available tools: {list(self.available_tools.keys())}")

        if self.session is None:
            return {"error": "MCP tool call exception: client session is not open"}

        request = {
            "tool": tool_name,
            "parameters": parameters
        }

        try:
            async with self.session.post(
                f"{self.mcp_server_url}/tools/call",
                json=request
            ) as response:
                if response.status == 200:
                    return await response.json()
                return {"error": f"MCP tool call failed: {response.status}"}
        except Exception as e:
            # Errors are reported to the caller as data rather than raised.
            return {"error": f"MCP tool call exception: {str(e)}"}
|
|
|
|
|
|
|
|
class BaseAgent(ABC):
    """Base class for all AI agents with security and MCP integration.

    Subclasses implement :meth:`process_request`. Callers use
    :meth:`handle_user_input`, which screens input for prompt injection and
    sanitizes the output (including error messages) before returning it.
    """

    def __init__(self, name: str, description: str, mcp_server_url: str, config: Dict[str, Any]):
        self.name = name
        self.description = description
        self.mcp_server_url = mcp_server_url
        self.config = config
        self.security_middleware = SecurityMiddleware(config)
        # Incremented once per ID produced by _generate_session_id.
        self.session_counter = 0

        # NOTE(review): this configures the *root* logger as a side effect of
        # constructing an agent. basicConfig does nothing once the root logger
        # already has handlers, so the first configuration wins; consider
        # moving this to the application entry point.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

    def _generate_session_id(self) -> str:
        """Return an ID of the form ``{name}_{counter}_{YYYYmmdd_HHMMSS}`` (UTC).

        The per-instance counter keeps IDs distinct within one second.
        """
        self.session_counter += 1
        return f"{self.name}_{self.session_counter}_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"

    def _check_permissions(self, session_id: str, action: str) -> bool:
        """Check if user has permission for the requested action (RBAC).

        The role is read from ``config["user_roles"][session_id]`` and defaults
        to ``"basic_user"``. Only the admin actions below are restricted; every
        other action is allowed.
        """
        # NOTE(review): roles are keyed by session ID, but handle_user_input
        # generates a fresh ID per request, so a user_roles entry only matches
        # if it uses that generated ID — confirm this is intended.
        role = self.config.get("user_roles", {}).get(session_id, "basic_user")

        admin_actions = ["crm_admin", "ticket_admin", "calendar_admin"]
        if action in admin_actions and role not in ["admin", "enterprise_user"]:
            return False

        return True

    @abstractmethod
    async def process_request(self, user_input: str, session_id: Optional[str] = None) -> str:
        """Process user input and return agent response."""
        pass

    async def handle_user_input(self, user_input: str) -> str:
        """Main handler for user input with security checks.

        Flow: allocate a session ID, reject inputs matching an injection
        pattern, delegate to :meth:`process_request`, and mask sensitive data
        in whatever text is returned to the user.

        Returns:
            The sanitized agent response, a fixed "request blocked" message, or
            a sanitized error message if ``process_request`` raised.
        """
        session_id = self._generate_session_id()

        if self.security_middleware.detect_prompt_injection(user_input):
            # Record blocked attempts; the input itself is deliberately not logged.
            logging.warning("Blocked potential prompt injection (session %s)", session_id)
            return "❌ Potential security threat detected. Request blocked."

        try:
            response = await self.process_request(user_input, session_id)
            return self.security_middleware.sanitize_output(response)
        except Exception as e:
            logging.exception("Error processing request (session %s)", session_id)
            # Exception text can echo user data (emails, phone numbers, ...);
            # mask it the same way as a normal response before returning it.
            safe_error = self.security_middleware.sanitize_output(str(e))
            return f"❌ An error occurred while processing your request: {safe_error}"

    def get_available_tools(self) -> List[str]:
        """Get list of available tools for this agent (subclasses may override)."""
        return ["general_inquiry"]

    def get_status(self) -> Dict[str, Any]:
        """Get current status of the agent as a JSON-friendly dict."""
        return {
            "name": self.name,
            "description": self.description,
            "mcp_server": self.mcp_server_url,
            "status": "active",
            "tools": self.get_available_tools(),
            "security_enabled": True,
            "audit_logging": True
        }