BeatDebate / src /api /logging_middleware.py
SulmanK's picture
Enhance PlannerAgent with entity recognition and coordination strategies - Implemented enhanced entity recognition and intent analysis in the PlannerAgent, allowing for improved query processing and coordination strategies. Updated the MusicRecommenderState model to include new fields for extracted entities, intent analysis, and conversation context. Enhanced logging and reasoning steps for better traceability during planning. Revised design documentation to reflect these changes and outline future enhancements.
00dd8ea
Raw
History Blame Contribute Delete
8.32 kB
"""
FastAPI Logging Middleware for BeatDebate
Provides comprehensive logging for all API requests and responses including:
- Request/response timing
- Status codes and error tracking
- Request IDs for tracing
- Performance monitoring
"""
import time
import uuid
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
try:
from ..utils.logging_config import get_logger, log_api_request, set_request_context
except RuntimeError:
# Logging not initialized yet - will use fallback
get_logger = None
log_api_request = None
set_request_context = None
class LoggingMiddleware(BaseHTTPMiddleware):
    """
    Middleware to log all API requests and responses.

    Provides:
    - Request/response timing
    - Status code tracking
    - Error logging
    - Request ID generation for tracing
    - Performance monitoring

    Every logging hook is optional: when the module-level ``get_logger``,
    ``log_api_request`` or ``set_request_context`` names are ``None`` the
    corresponding step is skipped and the request is still served.
    """

    # Paths skipped when the caller does not supply ``exclude_paths``.
    _DEFAULT_EXCLUDE_PATHS = ("/health", "/docs", "/openapi.json")

    # Lower-cased header names whose values must never be written to logs.
    _SENSITIVE_HEADERS = frozenset(
        {
            "authorization",
            "proxy-authorization",
            "cookie",
            "set-cookie",
            "x-api-key",
            "x-auth-token",
            "x-csrf-token",
        }
    )

    # Placeholder logged in place of a sensitive header value.
    _REDACTED = "[REDACTED]"

    def __init__(self, app, exclude_paths: list = None):
        """
        Initialize logging middleware.

        Args:
            app: FastAPI application
            exclude_paths: List of exact paths to exclude from logging.
                ``None`` selects the defaults (``/health``, ``/docs``,
                ``/openapi.json``); an empty list excludes nothing.
        """
        super().__init__(app)
        self.logger = get_logger("api.middleware") if get_logger else None
        # Compare against None rather than truthiness so that an explicit
        # empty list is honoured instead of being replaced by the defaults.
        if exclude_paths is None:
            exclude_paths = self._DEFAULT_EXCLUDE_PATHS
        self.exclude_paths = list(exclude_paths)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and response with logging.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, with an ``X-Request-ID`` header added
            (excluded paths are returned untouched).

        Raises:
            Exception: Whatever ``call_next`` raises is logged as
                ``api_request_error`` and re-raised unchanged.
        """
        # Skip logging for excluded paths
        if request.url.path in self.exclude_paths:
            return await call_next(request)

        # Generate request ID for tracing
        request_id = str(uuid.uuid4())

        # Set request context for structured logging
        if set_request_context:
            set_request_context(
                request_id=request_id,
                user_id=self._extract_user_id(request),
            )

        method = request.method
        url = str(request.url)

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        start_time = time.perf_counter()

        if self.logger:
            self.logger.info(
                "api_request_start",
                request_id=request_id,
                method=method,
                url=url,
                headers=self._sanitize_headers(request.headers),
                client_ip=self._get_client_ip(request),
            )

        # Only the downstream call is guarded: a failure inside the logging
        # code below must not be reported as a failed API request.
        try:
            response = await call_next(request)
        except Exception as e:
            duration = time.perf_counter() - start_time
            if self.logger:
                self.logger.error(
                    "api_request_error",
                    request_id=request_id,
                    method=method,
                    url=url,
                    error_type=type(e).__name__,
                    error_message=str(e),
                    duration_seconds=round(duration, 4),
                )
            # Re-raise so the application's own error handling still runs.
            raise
        else:
            duration = time.perf_counter() - start_time

            if self.logger:
                self.logger.info(
                    "api_request_complete",
                    request_id=request_id,
                    method=method,
                    url=url,
                    status_code=response.status_code,
                    duration_seconds=round(duration, 4),
                    response_size=response.headers.get("content-length", "unknown"),
                )

            # Log to performance metrics
            if log_api_request:
                log_api_request(
                    method=method,
                    url=url,
                    status_code=response.status_code,
                    duration=duration,
                    request_id=request_id,
                )

            # Expose the request ID so clients can correlate with the logs.
            response.headers["X-Request-ID"] = request_id
            return response

    @classmethod
    def _sanitize_headers(cls, headers) -> dict:
        """Return the headers as a dict with credential-bearing values redacted."""
        return {
            name: cls._REDACTED if name.lower() in cls._SENSITIVE_HEADERS else value
            for name, value in dict(headers).items()
        }

    def _extract_user_id(self, request: Request) -> str:
        """
        Extract user ID from request if available.

        Looks, in order, at the ``X-User-ID`` header, the ``user_id`` query
        parameter and the session's ``user_id`` entry.

        Returns:
            The first non-empty value found, or ``"anonymous"``.
        """
        user_id = (
            request.headers.get("X-User-ID")
            or request.query_params.get("user_id")
        )

        # Only touch request.session when the scope actually carries one;
        # the except clause is kept as a safety net for an unusable session.
        if not user_id and "session" in request.scope:
            try:
                user_id = request.session.get("user_id")
            except (AttributeError, AssertionError):
                user_id = None

        return user_id or "anonymous"

    def _get_client_ip(self, request: Request) -> str:
        """
        Get client IP address from request.

        Checks ``X-Forwarded-For`` (first entry), then ``X-Real-IP``, then the
        direct peer address.

        NOTE(review): both headers are supplied by the sender, so the value
        is only trustworthy when a proxy in front of the app overwrites them.

        Returns:
            The client address, or ``"unknown"`` when none is available.
        """
        # Check for forwarded headers first (for proxy/load balancer)
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            # A malformed value such as ", 10.0.0.1" has an empty first
            # entry; fall through instead of logging an empty address.
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip and real_ip.strip():
            return real_ip.strip()

        # Fall back to direct client IP
        return request.client.host if request.client else "unknown"
class PerformanceLoggingMiddleware(BaseHTTPMiddleware):
    """
    Specialized middleware for performance monitoring.

    Tracks:
    - Slow requests (configurable threshold)
    - Duration and status code of requests to a fixed set of key endpoints
    - Requests whose handling raises an exception

    Memory usage, database query times and external API call times are not
    measured by this middleware.

    When the module-level ``get_logger`` is ``None`` nothing is logged and
    requests simply pass through.
    """

    # Path fragments identifying the endpoints whose timing is always logged.
    _TRACKED_ENDPOINTS = (
        "/api/recommendations",
        "/api/chat",
        "/api/search",
    )

    def __init__(self, app, slow_request_threshold: float = 5.0):
        """
        Initialize performance logging middleware.

        Args:
            app: FastAPI application
            slow_request_threshold: Time in seconds to consider a request slow
        """
        super().__init__(app)
        self.logger = get_logger("performance") if get_logger else None
        self.slow_request_threshold = slow_request_threshold

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request with performance monitoring.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Whatever ``call_next`` raises is logged as
                ``request_exception`` and re-raised unchanged.
        """
        # perf_counter is monotonic, so durations are immune to system
        # clock adjustments made while the request is in flight.
        start_time = time.perf_counter()

        # Only the downstream call is guarded so that a failure inside the
        # logging code is not misreported as a request exception.
        try:
            response = await call_next(request)
        except Exception as e:
            duration = time.perf_counter() - start_time
            if self.logger:
                self.logger.error(
                    "request_exception",
                    method=request.method,
                    url=str(request.url),
                    duration_seconds=round(duration, 4),
                    error_type=type(e).__name__,
                    error_message=str(e),
                )
            raise
        else:
            duration = time.perf_counter() - start_time

            if self.logger:
                # Log slow requests
                if duration > self.slow_request_threshold:
                    self.logger.warning(
                        "slow_request",
                        method=request.method,
                        url=str(request.url),
                        duration_seconds=round(duration, 4),
                        threshold_seconds=self.slow_request_threshold,
                    )

                # Log performance metrics for specific endpoints
                if self._should_track_performance(request.url.path):
                    self.logger.info(
                        "endpoint_performance",
                        endpoint=request.url.path,
                        method=request.method,
                        duration_seconds=round(duration, 4),
                        status_code=response.status_code,
                    )

            return response

    def _should_track_performance(self, path: str) -> bool:
        """
        Determine if we should track performance for this endpoint.

        Args:
            path: URL path of the request.

        Returns:
            True when any tracked endpoint fragment occurs anywhere in
            ``path`` (substring match, not prefix match).
        """
        return any(tracked in path for tracked in self._TRACKED_ENDPOINTS)