Spaces:
Sleeping
Sleeping
| """ | |
| Distributed tracing integration for MediGuard AI. | |
| Uses OpenTelemetry for end-to-end request tracing. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from contextlib import asynccontextmanager | |
| from typing import Any | |
| from opentelemetry import baggage, context, trace | |
| from opentelemetry.exporter.jaeger.thrift import JaegerExporter | |
| from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter | |
| from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor | |
| from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor | |
| from opentelemetry.instrumentation.redis import RedisInstrumentor | |
| from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor | |
| from opentelemetry.propagate import set_global_textmap | |
| from opentelemetry.sdk.trace import TracerProvider | |
| from opentelemetry.sdk.trace.export import BatchSpanProcessor | |
| from opentelemetry.semconv.trace import SpanAttributes | |
| from opentelemetry.trace import SpanKind, Status, StatusCode | |
| from src.settings import get_settings | |
| logger = logging.getLogger(__name__) | |
| class DistributedTracer: | |
| """Manages distributed tracing configuration and operations.""" | |
| def __init__(self): | |
| self.tracer_provider: TracerProvider | None = None | |
| self.is_initialized = False | |
| self.service_name = "mediguard-api" | |
| self.service_version = "2.0.0" | |
| def initialize(self): | |
| """Initialize OpenTelemetry tracing.""" | |
| if self.is_initialized: | |
| return | |
| settings = get_settings() | |
| # Set up tracer provider | |
| self.tracer_provider = TracerProvider() | |
| trace.set_tracer_provider(self.tracer_provider) | |
| # Configure exporters based on environment | |
| exporters = [] | |
| # Jaeger exporter | |
| if os.getenv("JAEGER_ENDPOINT"): | |
| jaeger_exporter = JaegerExporter( | |
| endpoint=os.getenv("JAEGER_ENDPOINT"), | |
| collector_endpoint=os.getenv("JAEGER_COLLECTOR_ENDPOINT"), | |
| agent_host_name=os.getenv("JAEGER_AGENT_HOST", "localhost"), | |
| agent_port=int(os.getenv("JAEGER_AGENT_PORT", "6831")), | |
| ) | |
| exporters.append(jaeger_exporter) | |
| logger.info("Jaeger tracing enabled") | |
| # OTLP exporter (for services like Tempo, Honeycomb, etc.) | |
| if os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT"): | |
| otlp_exporter = OTLPSpanExporter( | |
| endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT"), | |
| headers=json.loads(os.getenv("OTEL_EXPORTER_OTLP_HEADERS", "{}")), | |
| ) | |
| exporters.append(otlp_exporter) | |
| logger.info("OTLP tracing enabled") | |
| # Add processors for each exporter | |
| for exporter in exporters: | |
| processor = BatchSpanProcessor(exporter) | |
| self.tracer_provider.add_span_processor(processor) | |
| # Set global propagator | |
| set_global_textmap({}) | |
| # Instrument libraries | |
| self._instrument_libraries() | |
| self.is_initialized = True | |
| logger.info("Distributed tracing initialized") | |
| def _instrument_libraries(self): | |
| """Instrument common libraries for automatic tracing.""" | |
| # FastAPI | |
| try: | |
| FastAPIInstrumentor.instrument_app( | |
| app=None, # Will be set when app is created | |
| tracer_provider=self.tracer_provider, | |
| excluded_urls=[ | |
| "/health", | |
| "/metrics", | |
| "/docs", | |
| "/redoc", | |
| "/openapi.json" | |
| ] | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Failed to instrument FastAPI: {e}") | |
| # HTTPX | |
| try: | |
| HTTPXClientInstrumentor().instrument() | |
| except Exception as e: | |
| logger.warning(f"Failed to instrument HTTPX: {e}") | |
| # Redis | |
| try: | |
| RedisInstrumentor().instrument() | |
| except Exception as e: | |
| logger.warning(f"Failed to instrument Redis: {e}") | |
| # SQLAlchemy (if used) | |
| try: | |
| SQLAlchemyInstrumentor().instrument() | |
| except Exception as e: | |
| logger.warning(f"Failed to instrument SQLAlchemy: {e}") | |
| def get_tracer(self, name: str = None): | |
| """Get a tracer instance.""" | |
| if not self.is_initialized: | |
| self.initialize() | |
| return trace.get_tracer(name or self.service_name) | |
| def shutdown(self): | |
| """Shutdown the tracer provider.""" | |
| if self.tracer_provider: | |
| self.tracer_provider.shutdown() | |
| self.is_initialized = False | |
| logger.info("Distributed tracing shutdown") | |
| # Global tracer instance | |
# One tracer object shared by every helper in this module.
_distributed_tracer = DistributedTracer()


def get_distributed_tracer() -> DistributedTracer:
    """Return the shared module-level ``DistributedTracer`` instance."""
    shared_tracer = _distributed_tracer
    return shared_tracer
| class TraceContext: | |
| """Helper class for managing trace context.""" | |
| def get_current_span() -> trace.Span: | |
| """Get the current span.""" | |
| return trace.get_current_span() | |
| def get_trace_id() -> str | None: | |
| """Get the current trace ID.""" | |
| span = trace.get_current_span() | |
| if span: | |
| span_context = span.get_span_context() | |
| if span_context.is_valid: | |
| return format(span_context.trace_id, "032x") | |
| return None | |
| def get_span_id() -> str | None: | |
| """Get the current span ID.""" | |
| span = trace.get_current_span() | |
| if span: | |
| span_context = span.get_span_context() | |
| if span_context.is_valid: | |
| return format(span_context.span_id, "016x") | |
| return None | |
| def set_baggage(key: str, value: str): | |
| """Set baggage item.""" | |
| baggage.set_baggage(key, value) | |
| def get_baggage(key: str) -> str | None: | |
| """Get baggage item.""" | |
| return baggage.get_baggage(key) | |
| def inject_headers(headers: dict[str, str]): | |
| """Inject trace context into headers.""" | |
| ctx = context.get_current() | |
| carrier = {} | |
| set_global_textmap().inject(carrier, ctx) | |
| headers.update(carrier) | |
| def extract_from_headers(headers: dict[str, str]): | |
| """Extract trace context from headers.""" | |
| ctx = set_global_textmap().extract(headers) | |
| return ctx | |
@asynccontextmanager
async def trace_span(
    name: str,
    kind: SpanKind = SpanKind.INTERNAL,
    attributes: dict[str, Any] | None = None,
):
    """Async context manager that runs its body inside a new current span.

    Use as ``async with trace_span("name") as span: ...``. Without the
    ``asynccontextmanager`` decorator this function was a bare async
    generator and could not be used in a ``with`` statement at all.

    Args:
        name: Span name.
        kind: Span kind; defaults to ``SpanKind.INTERNAL``.
        attributes: Optional attributes; each value is stored as ``str(value)``.

    Yields:
        The started span.
    """
    tracer = get_distributed_tracer().get_tracer()
    with tracer.start_as_current_span(name, kind=kind) as span:
        for key, value in (attributes or {}).items():
            span.set_attribute(key, str(value))
        yield span
def trace_function(
    name: str | None = None,
    kind: SpanKind = SpanKind.INTERNAL,
    attributes: dict[str, Any] | None = None,
):
    """Decorator that runs a sync or async function inside its own span.

    Args:
        name: Span name; defaults to ``"<module>.<function name>"``.
        kind: Span kind; defaults to ``SpanKind.INTERNAL``.
        attributes: Optional static attributes; values are stored as
            ``str(value)``.

    Returns:
        A decorator. The wrapper keeps the wrapped function's metadata
        (``functools.wraps``), sets status OK on success, and on an
        ``Exception`` sets status ERROR, records the exception once and
        re-raises it.
    """
    def decorator(func):
        import functools
        import inspect

        span_name = name or f"{func.__module__}.{func.__name__}"

        def open_span():
            # Automatic exception handling is switched off because the
            # wrappers below record the exception and set the status
            # themselves; leaving both on would record each failure twice.
            tracer = get_distributed_tracer().get_tracer()
            return tracer.start_as_current_span(
                span_name,
                kind=kind,
                record_exception=False,
                set_status_on_exception=False,
            )

        def annotate(span):
            for key, value in (attributes or {}).items():
                span.set_attribute(key, str(value))
            # Only identifying metadata is recorded; call arguments are left
            # out because they may carry sensitive data.
            span.set_attribute("function.name", func.__name__)
            span.set_attribute("function.module", func.__module__)

        def mark_failed(span, exc):
            span.set_status(Status(StatusCode.ERROR, str(exc)))
            span.record_exception(exc)

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                with open_span() as span:
                    annotate(span)
                    try:
                        result = await func(*args, **kwargs)
                    except Exception as exc:
                        mark_failed(span, exc)
                        raise
                    span.set_status(Status(StatusCode.OK))
                    return result

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            with open_span() as span:
                annotate(span)
                try:
                    result = func(*args, **kwargs)
                except Exception as exc:
                    mark_failed(span, exc)
                    raise
                span.set_status(Status(StatusCode.OK))
                return result

        return sync_wrapper

    return decorator
| class TracingMiddleware: | |
| """Custom middleware for enhanced tracing.""" | |
| def __init__(self, app): | |
| self.app = app | |
| async def __call__(self, scope, receive, send): | |
| """ASGI middleware implementation.""" | |
| if scope["type"] != "http": | |
| await self.app(scope, receive, send) | |
| return | |
| # Get tracer | |
| tracer = get_distributed_tracer().get_tracer("asgi") | |
| # Extract trace context from headers | |
| headers = dict(scope.get("headers", [])) | |
| ctx = TraceContext.extract_from_headers(headers) | |
| with tracer.start_as_current_span( | |
| f"{scope['method']} {scope['path']}", | |
| kind=SpanKind.SERVER, | |
| context=ctx | |
| ) as span: | |
| # Set standard attributes | |
| span.set_attribute(SpanAttributes.HTTP_METHOD, scope["method"]) | |
| span.set_attribute(SpanAttributes.HTTP_URL, scope.get("path", "")) | |
| span.set_attribute(SpanAttributes.HTTP_SCHEME, scope.get("scheme", "http")) | |
| span.set_attribute(SpanAttributes.HTTP_HOST, scope.get("server", ("", ""))[0]) | |
| span.set_attribute(SpanAttributes.HTTP_USER_AGENT, self._get_header(headers, b"user-agent")) | |
| span.set_attribute(SpanAttributes.HTTP_CLIENT_IP, self._get_client_ip(scope)) | |
| # Add custom baggage | |
| TraceContext.set_baggage("service.name", "mediguard-api") | |
| TraceContext.set_baggage("service.version", "2.0.0") | |
| # Capture start time | |
| start_time = time.time() | |
| # Wrap send to capture response | |
| async def traced_send(message): | |
| if message["type"] == "http.response.start": | |
| # Set response attributes | |
| status = message.get("status", 200) | |
| span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status) | |
| # Mark error status | |
| if status >= 400: | |
| span.set_status(Status(StatusCode.ERROR)) | |
| else: | |
| span.set_status(Status(StatusCode.OK)) | |
| elif message["type"] == "http.response.body": | |
| # Calculate duration | |
| duration = time.time() - start_time | |
| span.set_attribute("http.response.duration_ms", duration * 1000) | |
| await send(message) | |
| await self.app(scope, receive, traced_send) | |
| def _get_header(self, headers: dict[bytes, bytes], name: bytes) -> str | None: | |
| """Get header value by name.""" | |
| for key, value in headers.items(): | |
| if key.lower() == name.lower(): | |
| return value.decode("utf-8") | |
| return None | |
| def _get_client_ip(self, scope: dict[str, Any]) -> str | None: | |
| """Extract client IP from scope.""" | |
| # Check for forwarded headers | |
| headers = dict(scope.get("headers", [])) | |
| # X-Forwarded-For | |
| xff = self._get_header(headers, b"x-forwarded-for") | |
| if xff: | |
| return xff.split(",")[0].strip() | |
| # X-Real-IP | |
| xri = self._get_header(headers, b"x-real-ip") | |
| if xri: | |
| return xri | |
| # Client from scope | |
| client = scope.get("client") | |
| if client: | |
| return client[0] | |
| return None | |
| # Specialized tracing for different components | |
class DatabaseTracer:
    """Tracing utilities for database operations.

    The helpers annotate the span that is current when they are called; they
    do not start a span of their own.
    """

    @staticmethod
    async def trace_query(
        query: str,
        params: dict[str, Any] | None = None,
        system: str = "opensearch",
    ):
        """Record a database query on the current span.

        Args:
            query: Query text, stored as ``db.query``.
            params: Optional parameters, stored as ``str(params)`` in
                ``db.params`` when non-empty.
            system: Value for ``db.system``; defaults to ``"opensearch"``.
        """
        span = TraceContext.get_current_span()
        span.set_attribute("db.query", query)
        span.set_attribute("db.system", system)
        if params:
            # NOTE(review): parameters are exported verbatim; confirm they
            # cannot carry sensitive values.
            span.set_attribute("db.params", str(params))

    @staticmethod
    async def trace_bulk_operation(operation: str, count: int):
        """Record a bulk database operation on the current span.

        Args:
            operation: Operation name, stored as ``db.operation``.
            count: Number of items, stored as ``db.bulk_count``.
        """
        span = TraceContext.get_current_span()
        span.set_attribute("db.operation", operation)
        span.set_attribute("db.bulk_count", count)
class LLTracer:
    """Tracing utilities for LLM operations.

    The helper annotates the span that is current when it is called; it does
    not start a span of its own.
    """

    # Upper bound on the number of prompt/response characters stored.
    _MAX_TEXT_CHARS = 1000

    @staticmethod
    async def trace_llm_call(
        model: str,
        prompt: str,
        response: str | None = None,
        tokens: dict[str, int] | None = None,
        provider: str = "openai",
    ):
        """Record an LLM API call on the current span.

        Args:
            model: Model identifier, stored as ``llm.model``.
            prompt: Prompt text; only the first 1000 characters are stored.
            response: Optional response text; truncated the same way.
            tokens: Optional counts under the keys ``"prompt"``,
                ``"completion"`` and ``"total"``; missing keys are stored as 0.
            provider: Value for ``llm.provider``; defaults to ``"openai"``.
        """
        limit = LLTracer._MAX_TEXT_CHARS
        span = TraceContext.get_current_span()
        span.set_attribute("llm.model", model)
        # Truncation bounds the attribute size. It does not remove sensitive
        # content: the first characters are still exported as-is.
        span.set_attribute("llm.prompt", prompt[:limit])
        span.set_attribute("llm.provider", provider)
        if response:
            span.set_attribute("llm.response", response[:limit])
        if tokens:
            span.set_attribute("llm.tokens.prompt", tokens.get("prompt", 0))
            span.set_attribute("llm.tokens.completion", tokens.get("completion", 0))
            span.set_attribute("llm.tokens.total", tokens.get("total", 0))
class CacheTracer:
    """Tracing utilities for cache operations.

    The helper annotates the span that is current when it is called; it does
    not start a span of its own.
    """

    @staticmethod
    async def trace_cache_operation(
        operation: str,
        key: str,
        hit: bool | None = None,
        system: str = "redis",
    ):
        """Record a cache operation on the current span.

        Args:
            operation: Operation name, stored as ``cache.operation``.
            key: Cache key, stored as ``cache.key``.
            hit: Hit/miss flag; ``cache.hit`` is only set when this is not None.
            system: Value for ``cache.system``; defaults to ``"redis"``.
        """
        span = TraceContext.get_current_span()
        span.set_attribute("cache.operation", operation)
        span.set_attribute("cache.key", key)
        span.set_attribute("cache.system", system)
        if hit is not None:
            span.set_attribute("cache.hit", hit)
class WorkflowTracer:
    """Tracing utilities for workflow operations.

    The helper annotates the span that is current when it is called; it does
    not start a span of its own.
    """

    @staticmethod
    async def trace_workflow_step(
        workflow_name: str,
        step_name: str,
        step_duration: float,
        success: bool,
    ):
        """Record one workflow step on the current span.

        Args:
            workflow_name: Stored as ``workflow.name``.
            step_name: Stored as ``workflow.step``.
            step_duration: Multiplied by 1000 and stored as
                ``workflow.step.duration_ms`` (so presumably seconds;
                confirm with callers).
            success: Stored as ``workflow.step.success``.
        """
        span = TraceContext.get_current_span()
        span.set_attribute("workflow.name", workflow_name)
        span.set_attribute("workflow.step", step_name)
        span.set_attribute("workflow.step.duration_ms", step_duration * 1000)
        span.set_attribute("workflow.step.success", success)
| # Integration with existing services | |
async def trace_http_request(
    method: str,
    url: str,
    headers: dict[str, str] | None = None,
    status_code: int | None = None,
    duration_ms: float | None = None,
):
    """Record an outgoing HTTP request as a CLIENT span.

    The span is opened directly on the shared tracer. The earlier version
    entered the ``trace_span`` async generator with a plain ``with``, which
    raised ``TypeError`` on every call.

    Args:
        method: HTTP method; also used in the span name ``"HTTP <method>"``.
        url: Request URL, stored as ``http.url``.
        headers: Accepted for interface compatibility; not recorded.
        status_code: Response status; stored when given, and a value of 400
            or above marks the span as ERROR.
        duration_ms: Request duration in milliseconds; stored when given.
    """
    tracer = get_distributed_tracer().get_tracer()
    with tracer.start_as_current_span(f"HTTP {method}", kind=SpanKind.CLIENT) as span:
        span.set_attribute("http.method", method)
        span.set_attribute("http.url", url)
        # None is not a valid attribute value; unknown fields are left unset
        # instead of being exported as the string "None".
        if status_code is not None:
            span.set_attribute("http.status_code", status_code)
        if duration_ms is not None:
            span.set_attribute("http.duration_ms", duration_ms)
        if status_code and status_code >= 400:
            span.set_status(Status(StatusCode.ERROR))
| # Metrics integration | |
class TraceMetrics:
    """Accumulate simple per-span-name metrics from recorded span data."""

    def __init__(self):
        # All three maps are keyed by span name.
        self.request_counts: dict[str, int] = {}
        self.error_counts: dict[str, int] = {}
        self.response_times: dict[str, list[float]] = {}

    def record_span(self, span_data: dict[str, Any]):
        """Fold one span record into the running totals.

        Reads ``"name"`` (default ``""``), ``"duration_ms"`` (default ``0``)
        and ``"status"`` (default ``"ok"``) from ``span_data``. Any status
        other than ``"ok"`` counts as an error.
        """
        span_name = span_data.get("name", "")
        elapsed_ms = span_data.get("duration_ms", 0)
        outcome = span_data.get("status", "ok")

        self.request_counts.setdefault(span_name, 0)
        self.request_counts[span_name] += 1

        if outcome != "ok":
            self.error_counts.setdefault(span_name, 0)
            self.error_counts[span_name] += 1

        self.response_times.setdefault(span_name, []).append(elapsed_ms)

    def get_metrics(self) -> dict[str, Any]:
        """Return request counts, error counts and mean duration per name."""
        averages = {}
        for span_name, samples in self.response_times.items():
            if not samples:
                continue
            averages[span_name] = sum(samples) / len(samples)

        return {
            "request_counts": self.request_counts,
            "error_counts": self.error_counts,
            "avg_response_times": averages,
        }
| # Initialization function for FastAPI app | |
def initialize_tracing(app):
    """Initialize tracing for a FastAPI application.

    Initializes the shared distributed tracer, registers ``TracingMiddleware``
    on ``app`` and applies the FastAPI auto-instrumentation with the tracer's
    own provider.

    Args:
        app: The FastAPI application to instrument.
    """
    tracer = get_distributed_tracer()
    tracer.initialize()

    app.add_middleware(TracingMiddleware)

    # ``excluded_urls`` is a comma-separated string, not a list. These are the
    # operational endpoints that should not be auto-instrumented.
    # NOTE(review): TracingMiddleware and the FastAPI instrumentation both
    # open a server span per request; confirm that both are wanted.
    FastAPIInstrumentor.instrument_app(
        app,
        tracer_provider=tracer.tracer_provider,
        excluded_urls=",".join(
            ("/health", "/metrics", "/docs", "/redoc", "/openapi.json")
        ),
    )

    logger.info("Tracing initialized for FastAPI application")