# muzakkirhussain011's picture
# Add application files (text files only)
# 8bab08d
"""
Enterprise Prometheus Metrics for MCP Servers
Features:
- Request metrics (count, duration, errors)
- MCP-specific metrics
- Business metrics (prospects, contacts, emails)
- System metrics (database connections, cache hit rate)
"""
import os
import time
import logging
from typing import Optional
from functools import wraps
from aiohttp import web
from prometheus_client import (
Counter,
Histogram,
Gauge,
Summary,
Info,
CollectorRegistry,
generate_latest,
CONTENT_TYPE_LATEST
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class MCPMetrics:
    """Prometheus metrics for MCP servers.

    All collectors are created in ``__init__`` and registered on
    ``self.registry``. The ``record_*`` methods are thin helpers that select
    the right label set and update the matching collectors.
    """

    def __init__(self, registry: Optional["CollectorRegistry"] = None):
        """Create and register every collector.

        Args:
            registry: Registry the collectors are attached to. A new
                ``CollectorRegistry`` is created when none is given.
        """
        # Explicit None test: only a missing registry triggers creation.
        self.registry = registry if registry is not None else CollectorRegistry()

        # Service info (values come from the environment, with defaults)
        self.service_info = Info(
            'mcp_service',
            'MCP Service Information',
            registry=self.registry
        )
        self.service_info.info({
            'service': os.getenv('SERVICE_NAME', 'cx_ai_agent'),
            'version': os.getenv('VERSION', '1.0.0'),
            'environment': os.getenv('ENVIRONMENT', 'development')
        })

        # HTTP Request Metrics
        self.http_requests_total = Counter(
            'mcp_http_requests_total',
            'Total HTTP requests',
            ['method', 'path', 'status'],
            registry=self.registry
        )
        self.http_request_duration = Histogram(
            'mcp_http_request_duration_seconds',
            'HTTP request duration in seconds',
            ['method', 'path'],
            buckets=(0.001, 0.01, 0.1, 0.5, 1.0, 2.5, 5.0, 10.0),
            registry=self.registry
        )
        self.http_request_size = Summary(
            'mcp_http_request_size_bytes',
            'HTTP request size in bytes',
            ['method', 'path'],
            registry=self.registry
        )
        self.http_response_size = Summary(
            'mcp_http_response_size_bytes',
            'HTTP response size in bytes',
            ['method', 'path'],
            registry=self.registry
        )

        # MCP-Specific Metrics
        self.mcp_calls_total = Counter(
            'mcp_calls_total',
            'Total MCP method calls',
            ['server', 'method', 'status'],
            registry=self.registry
        )
        self.mcp_call_duration = Histogram(
            'mcp_call_duration_seconds',
            'MCP call duration in seconds',
            ['server', 'method'],
            buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0),
            registry=self.registry
        )

        # Business Metrics
        self.prospects_total = Gauge(
            'mcp_prospects_total',
            'Total number of prospects',
            ['status', 'tenant_id'],
            registry=self.registry
        )
        self.contacts_total = Gauge(
            'mcp_contacts_total',
            'Total number of contacts',
            ['tenant_id'],
            registry=self.registry
        )
        self.companies_total = Gauge(
            'mcp_companies_total',
            'Total number of companies',
            ['tenant_id'],
            registry=self.registry
        )
        self.emails_sent_total = Counter(
            'mcp_emails_sent_total',
            'Total emails sent',
            ['tenant_id'],
            registry=self.registry
        )
        self.meetings_booked_total = Counter(
            'mcp_meetings_booked_total',
            'Total meetings booked',
            ['tenant_id'],
            registry=self.registry
        )

        # Database Metrics
        self.db_connections = Gauge(
            'mcp_db_connections',
            'Number of active database connections',
            registry=self.registry
        )
        self.db_queries_total = Counter(
            'mcp_db_queries_total',
            'Total database queries',
            ['operation', 'table'],
            registry=self.registry
        )
        self.db_query_duration = Histogram(
            'mcp_db_query_duration_seconds',
            'Database query duration',
            ['operation', 'table'],
            buckets=(0.001, 0.01, 0.05, 0.1, 0.5, 1.0),
            registry=self.registry
        )

        # Cache Metrics (for Redis)
        self.cache_hits_total = Counter(
            'mcp_cache_hits_total',
            'Total cache hits',
            ['cache_name'],
            registry=self.registry
        )
        self.cache_misses_total = Counter(
            'mcp_cache_misses_total',
            'Total cache misses',
            ['cache_name'],
            registry=self.registry
        )

        # Authentication Metrics
        self.auth_attempts_total = Counter(
            'mcp_auth_attempts_total',
            'Total authentication attempts',
            ['result'],  # success, failed, expired
            registry=self.registry
        )
        self.rate_limit_exceeded_total = Counter(
            'mcp_rate_limit_exceeded_total',
            'Total rate limit exceeded events',
            ['client_id', 'path'],
            registry=self.registry
        )

        # Error Metrics
        self.errors_total = Counter(
            'mcp_errors_total',
            'Total errors',
            ['error_type', 'component'],
            registry=self.registry
        )

        logger.info("Prometheus metrics initialized")

    def record_http_request(
        self,
        method: str,
        path: str,
        status: int,
        duration: float,
        request_size: Optional[int] = None,
        response_size: Optional[int] = None
    ):
        """Record HTTP request metrics.

        Args:
            method: HTTP method, used as a label.
            path: Request path, used as a label.
            status: Response status code, used as a label on the counter.
            duration: Request duration in seconds.
            request_size: Request body size in bytes; ``None`` means
                "unknown" and skips the size observation.
            response_size: Response body size in bytes; ``None`` means
                "unknown" and skips the size observation.
        """
        self.http_requests_total.labels(method=method, path=path, status=status).inc()
        self.http_request_duration.labels(method=method, path=path).observe(duration)

        # Compare against None rather than truthiness: an empty body has a
        # valid size of 0 and must still be observed, otherwise the size
        # summaries under-count requests and skew their averages upward.
        if request_size is not None:
            self.http_request_size.labels(method=method, path=path).observe(request_size)
        if response_size is not None:
            self.http_response_size.labels(method=method, path=path).observe(response_size)

    def record_mcp_call(
        self,
        server: str,
        method: str,
        duration: float,
        success: bool = True
    ):
        """Record MCP call metrics.

        Args:
            server: MCP server name, used as a label.
            method: MCP method name, used as a label.
            duration: Call duration in seconds.
            success: Selects the ``status`` label: ``'success'`` when true,
                ``'error'`` otherwise.
        """
        status = 'success' if success else 'error'
        self.mcp_calls_total.labels(server=server, method=method, status=status).inc()
        self.mcp_call_duration.labels(server=server, method=method).observe(duration)

    def record_db_query(
        self,
        operation: str,
        table: str,
        duration: float
    ):
        """Record database query metrics.

        Args:
            operation: Query operation, used as a label.
            table: Table name, used as a label.
            duration: Query duration in seconds.
        """
        self.db_queries_total.labels(operation=operation, table=table).inc()
        self.db_query_duration.labels(operation=operation, table=table).observe(duration)

    def record_cache_access(self, cache_name: str, hit: bool):
        """Record a cache access as a hit or a miss for ``cache_name``."""
        if hit:
            self.cache_hits_total.labels(cache_name=cache_name).inc()
        else:
            self.cache_misses_total.labels(cache_name=cache_name).inc()

    def record_auth_attempt(self, result: str):
        """Record an authentication attempt labelled with ``result``."""
        self.auth_attempts_total.labels(result=result).inc()

    def record_rate_limit_exceeded(self, client_id: str, path: str):
        """Record a rate-limit-exceeded event for ``client_id`` on ``path``."""
        self.rate_limit_exceeded_total.labels(client_id=client_id, path=path).inc()

    def record_error(self, error_type: str, component: str):
        """Record an error labelled with its type and originating component."""
        self.errors_total.labels(error_type=error_type, component=component).inc()
class MetricsMiddleware:
    """aiohttp middleware for automatic metrics collection.

    Wraps every request except ``/metrics`` and reports method, path,
    status, duration and body sizes to the supplied ``MCPMetrics``.
    """

    def __init__(self, metrics: MCPMetrics):
        """Store the metrics sink used by ``middleware``."""
        self.metrics = metrics
        logger.info("Metrics middleware initialized")

    @web.middleware
    async def middleware(self, request: web.Request, handler):
        """Time the downstream handler and record the outcome.

        Exceptions raised by the handler are always re-raised after being
        recorded; this middleware never swallows them.
        """
        # Skip metrics endpoint itself
        if request.path == '/metrics':
            return await handler(request)

        # NOTE(review): the raw request path is used as a label value, so
        # parameterised or unmatched URLs each create a new time series.
        # perf_counter() is monotonic, unlike time.time(), so the measured
        # duration cannot go negative or jump on wall-clock adjustments.
        started = time.perf_counter()

        # Only the handler call is guarded, so a failure while recording
        # metrics is never misreported as a handler error.
        try:
            response = await handler(request)
        except web.HTTPException as exc:
            # Handlers may raise HTTP responses (redirects, 404, ...) as
            # exceptions; report their real status instead of a blanket 500.
            self.metrics.record_http_request(
                method=request.method,
                path=request.path,
                status=exc.status,
                duration=time.perf_counter() - started
            )
            # Only server-side failures count as errors.
            if exc.status >= 500:
                self.metrics.record_error(
                    error_type=type(exc).__name__,
                    component='http_handler'
                )
            raise
        except Exception as exc:
            # Unhandled exception: recorded here as status 500.
            self.metrics.record_http_request(
                method=request.method,
                path=request.path,
                status=500,
                duration=time.perf_counter() - started
            )
            self.metrics.record_error(
                error_type=type(exc).__name__,
                component='http_handler'
            )
            raise

        duration = time.perf_counter() - started
        self.metrics.record_http_request(
            method=request.method,
            path=request.path,
            status=response.status,
            duration=duration,
            request_size=request.content_length or 0,
            response_size=self._response_size(response)
        )
        return response

    @staticmethod
    def _response_size(response) -> int:
        """Return the response body size in bytes, or 0 when unknown."""
        body = getattr(response, 'body', None)
        if isinstance(body, (bytes, bytearray, memoryview)):
            return len(body)
        # The body may be absent (streaming responses) or a payload object
        # that does not support len(); fall back to the declared length.
        return getattr(response, 'content_length', None) or 0
def metrics_endpoint(metrics: MCPMetrics):
    """
    Create metrics endpoint handler

    Args:
        metrics: Metrics object whose ``registry`` is rendered on each call

    Returns:
        aiohttp handler function
    """
    async def handler(request: web.Request):
        """Serve Prometheus metrics"""
        metrics_output = generate_latest(metrics.registry)
        # CONTENT_TYPE_LATEST carries a "charset=..." parameter, and
        # web.Response raises ValueError when content_type= contains a
        # charset. Send the value as a raw Content-Type header instead.
        return web.Response(
            body=metrics_output,
            headers={'Content-Type': CONTENT_TYPE_LATEST}
        )

    return handler
def track_mcp_call(metrics: "MCPMetrics", server: str):
    """
    Decorator to track MCP call metrics

    The wrapped coroutine's ``__name__`` is reported as the MCP method. A
    call counts as successful only if it returns normally; any exception,
    including task cancellation, is recorded as an error and re-raised.

    Args:
        metrics: Object providing ``record_mcp_call``
        server: Server label recorded with every call

    Usage:
        @track_mcp_call(metrics, "search")
        async def search_query(query: str):
            ...
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Monotonic clock: immune to wall-clock adjustments.
            started = time.perf_counter()
            # Start pessimistic and flip only after a normal return, so
            # BaseException subclasses such as asyncio.CancelledError are
            # not counted as successes.
            succeeded = False
            try:
                result = await func(*args, **kwargs)
                succeeded = True
                return result
            finally:
                metrics.record_mcp_call(
                    server=server,
                    method=func.__name__,
                    duration=time.perf_counter() - started,
                    success=succeeded
                )
        return wrapper
    return decorator
# Shared module-wide instance, built lazily by get_metrics()
_metrics: Optional[MCPMetrics] = None


def get_metrics() -> MCPMetrics:
    """Return the shared MCPMetrics instance, creating it on first use."""
    global _metrics
    if _metrics is not None:
        return _metrics
    _metrics = MCPMetrics()
    return _metrics
# Example usage
if __name__ == "__main__":
    metrics = get_metrics()

    # Feed one sample into several of the collectors
    metrics.record_http_request(
        method="POST",
        path="/rpc",
        status=200,
        duration=0.05,
        request_size=1024,
        response_size=2048,
    )
    metrics.record_mcp_call(
        server="search",
        method="search.query",
        duration=0.1,
        success=True,
    )
    metrics.record_db_query(operation="SELECT", table="prospects", duration=0.02)
    metrics.record_cache_access(cache_name="company_cache", hit=True)
    metrics.record_auth_attempt(result="success")

    # Render everything registered so far and print it
    exposition = generate_latest(metrics.registry)
    print(exposition.decode())