|
|
""" |
|
|
Enterprise Prometheus Metrics for MCP Servers |
|
|
|
|
|
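Features:
- HTTP request metrics (count, duration, request/response sizes)
- MCP call metrics (count and duration per server and method)
- Business metrics (prospects, contacts, companies, emails sent, meetings booked)
- System metrics (database connections and queries, cache hits and misses)
- Security and error metrics (authentication attempts, rate-limit rejections, errors by type)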
|
|
""" |
|
|
import os |
|
|
import time |
|
|
import logging |
|
|
from typing import Optional |
|
|
from functools import wraps |
|
|
from aiohttp import web |
|
|
|
|
|
from prometheus_client import ( |
|
|
Counter, |
|
|
Histogram, |
|
|
Gauge, |
|
|
Summary, |
|
|
Info, |
|
|
CollectorRegistry, |
|
|
generate_latest, |
|
|
CONTENT_TYPE_LATEST |
|
|
) |
|
|
|
|
|
# Module-level logger; used to log initialization of the metrics objects below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class MCPMetrics:
    """Prometheus metrics for MCP servers.

    Every collector is registered on ``self.registry``. The ``record_*``
    methods are small helpers that update the matching collectors with the
    supplied label values.
    """

    def __init__(self, registry: Optional[CollectorRegistry] = None):
        """
        Create all collectors and register them on one registry.

        Args:
            registry: Registry to register the collectors on. When omitted, a
                fresh ``CollectorRegistry`` is created for this instance.
        """
        # Explicit None check so an injected registry is always honoured as-is.
        self.registry = registry if registry is not None else CollectorRegistry()

        # --- Service information (values read from the environment) ---
        self.service_info = Info(
            'mcp_service',
            'MCP Service Information',
            registry=self.registry,
        )
        self.service_info.info({
            'service': os.getenv('SERVICE_NAME', 'cx_ai_agent'),
            'version': os.getenv('VERSION', '1.0.0'),
            'environment': os.getenv('ENVIRONMENT', 'development'),
        })

        # --- HTTP metrics ---
        self.http_requests_total = Counter(
            'mcp_http_requests_total',
            'Total HTTP requests',
            ['method', 'path', 'status'],
            registry=self.registry,
        )
        self.http_request_duration = Histogram(
            'mcp_http_request_duration_seconds',
            'HTTP request duration in seconds',
            ['method', 'path'],
            buckets=(0.001, 0.01, 0.1, 0.5, 1.0, 2.5, 5.0, 10.0),
            registry=self.registry,
        )
        self.http_request_size = Summary(
            'mcp_http_request_size_bytes',
            'HTTP request size in bytes',
            ['method', 'path'],
            registry=self.registry,
        )
        self.http_response_size = Summary(
            'mcp_http_response_size_bytes',
            'HTTP response size in bytes',
            ['method', 'path'],
            registry=self.registry,
        )

        # --- MCP call metrics ---
        self.mcp_calls_total = Counter(
            'mcp_calls_total',
            'Total MCP method calls',
            ['server', 'method', 'status'],
            registry=self.registry,
        )
        self.mcp_call_duration = Histogram(
            'mcp_call_duration_seconds',
            'MCP call duration in seconds',
            ['server', 'method'],
            buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0),
            registry=self.registry,
        )

        # --- Business metrics (labelled per tenant) ---
        self.prospects_total = Gauge(
            'mcp_prospects_total',
            'Total number of prospects',
            ['status', 'tenant_id'],
            registry=self.registry,
        )
        self.contacts_total = Gauge(
            'mcp_contacts_total',
            'Total number of contacts',
            ['tenant_id'],
            registry=self.registry,
        )
        self.companies_total = Gauge(
            'mcp_companies_total',
            'Total number of companies',
            ['tenant_id'],
            registry=self.registry,
        )
        self.emails_sent_total = Counter(
            'mcp_emails_sent_total',
            'Total emails sent',
            ['tenant_id'],
            registry=self.registry,
        )
        self.meetings_booked_total = Counter(
            'mcp_meetings_booked_total',
            'Total meetings booked',
            ['tenant_id'],
            registry=self.registry,
        )

        # --- Database metrics ---
        self.db_connections = Gauge(
            'mcp_db_connections',
            'Number of active database connections',
            registry=self.registry,
        )
        self.db_queries_total = Counter(
            'mcp_db_queries_total',
            'Total database queries',
            ['operation', 'table'],
            registry=self.registry,
        )
        self.db_query_duration = Histogram(
            'mcp_db_query_duration_seconds',
            'Database query duration',
            ['operation', 'table'],
            buckets=(0.001, 0.01, 0.05, 0.1, 0.5, 1.0),
            registry=self.registry,
        )

        # --- Cache metrics ---
        self.cache_hits_total = Counter(
            'mcp_cache_hits_total',
            'Total cache hits',
            ['cache_name'],
            registry=self.registry,
        )
        self.cache_misses_total = Counter(
            'mcp_cache_misses_total',
            'Total cache misses',
            ['cache_name'],
            registry=self.registry,
        )

        # --- Security metrics ---
        self.auth_attempts_total = Counter(
            'mcp_auth_attempts_total',
            'Total authentication attempts',
            ['result'],
            registry=self.registry,
        )
        self.rate_limit_exceeded_total = Counter(
            'mcp_rate_limit_exceeded_total',
            'Total rate limit exceeded events',
            ['client_id', 'path'],
            registry=self.registry,
        )

        # --- Errors ---
        self.errors_total = Counter(
            'mcp_errors_total',
            'Total errors',
            ['error_type', 'component'],
            registry=self.registry,
        )

        logger.info("Prometheus metrics initialized")

    def record_http_request(
        self,
        method: str,
        path: str,
        status: int,
        duration: float,
        request_size: Optional[int] = None,
        response_size: Optional[int] = None
    ) -> None:
        """
        Record one HTTP request.

        Args:
            method: HTTP method label.
            path: Path label.
            status: Response status code (used as a label value).
            duration: Handling time in seconds.
            request_size: Request body size in bytes, or None if unknown.
            response_size: Response body size in bytes, or None if unknown.
        """
        self.http_requests_total.labels(method=method, path=path, status=status).inc()
        self.http_request_duration.labels(method=method, path=path).observe(duration)

        # Compare against None, not truthiness: a 0-byte body is a real
        # observation and skipping it would skew the size summaries.
        if request_size is not None:
            self.http_request_size.labels(method=method, path=path).observe(request_size)
        if response_size is not None:
            self.http_response_size.labels(method=method, path=path).observe(response_size)

    def record_mcp_call(
        self,
        server: str,
        method: str,
        duration: float,
        success: bool = True
    ) -> None:
        """Record one MCP call; ``success`` maps to the 'success'/'error' status label."""
        status = 'success' if success else 'error'
        self.mcp_calls_total.labels(server=server, method=method, status=status).inc()
        self.mcp_call_duration.labels(server=server, method=method).observe(duration)

    def record_db_query(
        self,
        operation: str,
        table: str,
        duration: float
    ) -> None:
        """Record one database query and its duration in seconds."""
        self.db_queries_total.labels(operation=operation, table=table).inc()
        self.db_query_duration.labels(operation=operation, table=table).observe(duration)

    def record_cache_access(self, cache_name: str, hit: bool) -> None:
        """Count a cache hit or miss for ``cache_name``."""
        counter = self.cache_hits_total if hit else self.cache_misses_total
        counter.labels(cache_name=cache_name).inc()

    def record_auth_attempt(self, result: str) -> None:
        """Count an authentication attempt labelled with ``result``."""
        self.auth_attempts_total.labels(result=result).inc()

    def record_rate_limit_exceeded(self, client_id: str, path: str) -> None:
        """Count a rate-limit rejection for ``client_id`` on ``path``."""
        self.rate_limit_exceeded_total.labels(client_id=client_id, path=path).inc()

    def record_error(self, error_type: str, component: str) -> None:
        """Count an error of ``error_type`` raised in ``component``."""
        self.errors_total.labels(error_type=error_type, component=component).inc()
|
|
|
|
|
|
|
|
class MetricsMiddleware:
    """aiohttp middleware that records per-request metrics into an MCPMetrics instance.

    The bound ``middleware`` method is the aiohttp middleware callable.
    """

    # ``path`` label used for requests that matched no registered route.
    _UNMATCHED_PATH = '<unmatched>'

    def __init__(self, metrics: MCPMetrics, group_by_route: bool = True):
        """
        Args:
            metrics: Destination for the recorded measurements.
            group_by_route: When True (default), the ``path`` label is the matched
                route's canonical template (e.g. ``/prospects/{id}``), and requests
                matching no route share one label. This keeps the number of label
                combinations bounded. Pass False to label with the raw
                ``request.path``.
        """
        self.metrics = metrics
        self.group_by_route = group_by_route
        logger.info("Metrics middleware initialized")

    def _path_label(self, request: web.Request) -> str:
        """Return the value used for the ``path`` label of ``request``."""
        if not self.group_by_route:
            return request.path
        match_info = request.match_info
        # Routing has already happened by the time middlewares run. Unmatched
        # requests (404/405) carry an http_exception and a route without a resource.
        if match_info is None or getattr(match_info, 'http_exception', None) is not None:
            return self._UNMATCHED_PATH
        resource = getattr(match_info.route, 'resource', None)
        if resource is None:
            return self._UNMATCHED_PATH
        return resource.canonical

    @staticmethod
    def _response_size(response) -> Optional[int]:
        """Best-effort response body size in bytes, or None when it cannot be known.

        ``len()`` is applied only to bytes-like bodies. A StreamResponse has no
        ``body`` attribute, and a Response may hold a Payload object, which need
        not support ``len()``.
        """
        if not hasattr(response, 'body'):
            return None
        body = response.body
        if body is None:
            return 0
        if isinstance(body, (bytes, bytearray)):
            return len(body)
        size = getattr(body, 'size', None)
        return size if isinstance(size, int) else None

    @web.middleware
    async def middleware(self, request: web.Request, handler):
        """Time ``handler`` and record the request outcome.

        Requests to ``/metrics`` are passed through unmeasured. Exceptions are
        recorded and then re-raised unchanged.
        """
        # Don't measure scrapes of the metrics endpoint itself.
        if request.path == '/metrics':
            return await handler(request)

        path = self._path_label(request)
        # perf_counter is monotonic, so durations cannot go negative on clock changes.
        start_time = time.perf_counter()

        try:
            response = await handler(request)
        except web.HTTPException as exc:
            # aiohttp handlers raise HTTPException subclasses (404, redirects, ...)
            # to produce ordinary responses: record the real status, not 500.
            self.metrics.record_http_request(
                method=request.method,
                path=path,
                status=exc.status,
                duration=time.perf_counter() - start_time,
            )
            if exc.status >= 500:
                self.metrics.record_error(
                    error_type=type(exc).__name__,
                    component='http_handler',
                )
            raise
        except Exception as exc:
            self.metrics.record_http_request(
                method=request.method,
                path=path,
                status=500,
                duration=time.perf_counter() - start_time,
            )
            self.metrics.record_error(
                error_type=type(exc).__name__,
                component='http_handler',
            )
            raise

        # Recorded outside the try so a problem while measuring a successful
        # response is never misreported as a handler failure.
        self.metrics.record_http_request(
            method=request.method,
            path=path,
            status=response.status,
            duration=time.perf_counter() - start_time,
            # None when the client sent no Content-Length (size unknown).
            request_size=request.content_length,
            response_size=self._response_size(response),
        )
        return response
|
|
|
|
|
|
|
|
def metrics_endpoint(metrics: MCPMetrics):
    """
    Create an aiohttp handler that serves ``metrics.registry`` in Prometheus text format.

    Args:
        metrics: Metrics instance whose registry is exposed.

    Returns:
        An async aiohttp request handler.
    """
    async def handler(request: web.Request) -> web.Response:
        """Serve Prometheus metrics"""
        payload = generate_latest(metrics.registry)
        # CONTENT_TYPE_LATEST contains a "charset=..." parameter, and aiohttp's
        # Response raises ValueError if the content_type argument includes a
        # charset. Set the full header value directly instead, which also keeps
        # the exposition-format version parameter intact.
        return web.Response(
            body=payload,
            headers={'Content-Type': CONTENT_TYPE_LATEST},
        )

    return handler
|
|
|
|
|
|
|
|
def track_mcp_call(metrics: "MCPMetrics", server: str):
    """
    Decorator factory that reports each call of an async function to ``metrics``.

    Every call, successful or not, is passed to ``metrics.record_mcp_call``. The
    call carries the given ``server``, the wrapped function's ``__name__`` as the
    method, and the elapsed time in seconds. The call counts as successful only
    if the coroutine returned normally.

    Usage:
        @track_mcp_call(metrics, "search")
        async def search_query(query: str):
            ...

    Args:
        metrics: Object exposing ``record_mcp_call`` (normally an MCPMetrics).
        server: Value for the ``server`` label.

    Returns:
        A decorator for ``async def`` functions. Exceptions from the wrapped
        function propagate unchanged.
    """
    def decorator(func):
        method_name = func.__name__

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # perf_counter is monotonic, unlike time.time().
            start_time = time.perf_counter()
            # Pessimistic default: only a normal return marks success. Otherwise
            # asyncio.CancelledError (a BaseException, not an Exception) would
            # slip past an ``except Exception`` and be recorded as a success.
            success = False
            try:
                result = await func(*args, **kwargs)
                success = True
                return result
            finally:
                metrics.record_mcp_call(
                    server=server,
                    method=method_name,
                    duration=time.perf_counter() - start_time,
                    success=success,
                )

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
|
|
# Process-wide MCPMetrics instance; created lazily by get_metrics().
_metrics: Optional[MCPMetrics] = None


def get_metrics() -> MCPMetrics:
    """Return the shared MCPMetrics instance, constructing it on first call.

    The instance is built with ``MCPMetrics()``, i.e. without an explicit
    registry argument, and is reused by every later call.
    """
    global _metrics
    if _metrics is not None:
        return _metrics
    _metrics = MCPMetrics()
    return _metrics
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: record one sample through several helpers, then print the
    # resulting Prometheus text exposition to stdout.
    metrics = get_metrics()

    metrics.record_http_request(
        method="POST",
        path="/rpc",
        status=200,
        duration=0.05,
        request_size=1024,
        response_size=2048,
    )
    metrics.record_mcp_call(server="search", method="search.query", duration=0.1, success=True)
    metrics.record_db_query(operation="SELECT", table="prospects", duration=0.02)
    metrics.record_cache_access(cache_name="company_cache", hit=True)
    metrics.record_auth_attempt(result="success")

    exposition = generate_latest(metrics.registry)
    print(exposition.decode())
|
|
|