Spaces:
Build error
Build error
Enhance PlannerAgent with entity recognition and coordination strategies - Implemented enhanced entity recognition and intent analysis in the PlannerAgent, allowing for improved query processing and coordination strategies. Updated the MusicRecommenderState model to include new fields for extracted entities, intent analysis, and conversation context. Enhanced logging and reasoning steps for better traceability during planning. Revised design documentation to reflect these changes and outline future enhancements.
00dd8ea | """ | |
| FastAPI Logging Middleware for BeatDebate | |
| Provides comprehensive logging for all API requests and responses including: | |
| - Request/response timing | |
| - Status codes and error tracking | |
| - Request IDs for tracing | |
| - Performance monitoring | |
| """ | |
| import time | |
| import uuid | |
| from typing import Callable | |
| from fastapi import Request, Response | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| try: | |
| from ..utils.logging_config import get_logger, log_api_request, set_request_context | |
| except RuntimeError: | |
| # Logging not initialized yet - will use fallback | |
| get_logger = None | |
| log_api_request = None | |
| set_request_context = None | |
| class LoggingMiddleware(BaseHTTPMiddleware): | |
| """ | |
| Middleware to log all API requests and responses. | |
| Provides: | |
| - Request/response timing | |
| - Status code tracking | |
| - Error logging | |
| - Request ID generation for tracing | |
| - Performance monitoring | |
| """ | |
| def __init__(self, app, exclude_paths: list = None): | |
| """ | |
| Initialize logging middleware. | |
| Args: | |
| app: FastAPI application | |
| exclude_paths: List of paths to exclude from logging | |
| """ | |
| super().__init__(app) | |
| self.logger = get_logger("api.middleware") if get_logger else None | |
| self.exclude_paths = exclude_paths or ["/health", "/docs", "/openapi.json"] | |
| async def dispatch(self, request: Request, call_next: Callable) -> Response: | |
| """Process request and response with logging.""" | |
| # Skip logging for excluded paths | |
| if request.url.path in self.exclude_paths: | |
| return await call_next(request) | |
| # Generate request ID for tracing | |
| request_id = str(uuid.uuid4()) | |
| # Set request context for structured logging | |
| if set_request_context: | |
| set_request_context( | |
| request_id=request_id, | |
| user_id=self._extract_user_id(request) | |
| ) | |
| # Log request start | |
| start_time = time.time() | |
| if self.logger: | |
| self.logger.info( | |
| "api_request_start", | |
| request_id=request_id, | |
| method=request.method, | |
| url=str(request.url), | |
| headers=dict(request.headers), | |
| client_ip=self._get_client_ip(request) | |
| ) | |
| # Process request | |
| try: | |
| response = await call_next(request) | |
| # Calculate duration | |
| duration = time.time() - start_time | |
| # Log successful response | |
| if self.logger: | |
| self.logger.info( | |
| "api_request_complete", | |
| request_id=request_id, | |
| method=request.method, | |
| url=str(request.url), | |
| status_code=response.status_code, | |
| duration_seconds=round(duration, 4), | |
| response_size=response.headers.get("content-length", "unknown") | |
| ) | |
| # Log to performance metrics | |
| if log_api_request: | |
| log_api_request( | |
| method=request.method, | |
| url=str(request.url), | |
| status_code=response.status_code, | |
| duration=duration, | |
| request_id=request_id | |
| ) | |
| # Add request ID to response headers | |
| response.headers["X-Request-ID"] = request_id | |
| return response | |
| except Exception as e: | |
| # Calculate duration for error case | |
| duration = time.time() - start_time | |
| # Log error | |
| if self.logger: | |
| self.logger.error( | |
| "api_request_error", | |
| request_id=request_id, | |
| method=request.method, | |
| url=str(request.url), | |
| error_type=type(e).__name__, | |
| error_message=str(e), | |
| duration_seconds=round(duration, 4) | |
| ) | |
| # Re-raise the exception | |
| raise | |
| def _extract_user_id(self, request: Request) -> str: | |
| """Extract user ID from request if available.""" | |
| # Check for user ID in headers, query params, or session | |
| user_id = ( | |
| request.headers.get("X-User-ID") or | |
| request.query_params.get("user_id") | |
| ) | |
| # Try to get from session if available (only if SessionMiddleware is installed) | |
| try: | |
| if hasattr(request, "session") and "session" in request.scope: | |
| user_id = user_id or request.session.get("user_id") | |
| except (AttributeError, AssertionError): | |
| # SessionMiddleware not installed or session not available | |
| pass | |
| return user_id or "anonymous" | |
| def _get_client_ip(self, request: Request) -> str: | |
| """Get client IP address from request.""" | |
| # Check for forwarded headers first (for proxy/load balancer) | |
| forwarded_for = request.headers.get("X-Forwarded-For") | |
| if forwarded_for: | |
| return forwarded_for.split(",")[0].strip() | |
| real_ip = request.headers.get("X-Real-IP") | |
| if real_ip: | |
| return real_ip | |
| # Fall back to direct client IP | |
| return request.client.host if request.client else "unknown" | |
| class PerformanceLoggingMiddleware(BaseHTTPMiddleware): | |
| """ | |
| Specialized middleware for performance monitoring. | |
| Tracks: | |
| - Slow requests (configurable threshold) | |
| - Memory usage | |
| - Database query times | |
| - External API call times | |
| """ | |
| def __init__(self, app, slow_request_threshold: float = 5.0): | |
| """ | |
| Initialize performance logging middleware. | |
| Args: | |
| app: FastAPI application | |
| slow_request_threshold: Time in seconds to consider a request slow | |
| """ | |
| super().__init__(app) | |
| self.logger = get_logger("performance") if get_logger else None | |
| self.slow_request_threshold = slow_request_threshold | |
| async def dispatch(self, request: Request, call_next: Callable) -> Response: | |
| """Process request with performance monitoring.""" | |
| start_time = time.time() | |
| try: | |
| response = await call_next(request) | |
| duration = time.time() - start_time | |
| # Log slow requests | |
| if duration > self.slow_request_threshold and self.logger: | |
| self.logger.warning( | |
| "slow_request", | |
| method=request.method, | |
| url=str(request.url), | |
| duration_seconds=round(duration, 4), | |
| threshold_seconds=self.slow_request_threshold | |
| ) | |
| # Log performance metrics for specific endpoints | |
| if self._should_track_performance(request.url.path) and self.logger: | |
| self.logger.info( | |
| "endpoint_performance", | |
| endpoint=request.url.path, | |
| method=request.method, | |
| duration_seconds=round(duration, 4), | |
| status_code=response.status_code | |
| ) | |
| return response | |
| except Exception as e: | |
| duration = time.time() - start_time | |
| if self.logger: | |
| self.logger.error( | |
| "request_exception", | |
| method=request.method, | |
| url=str(request.url), | |
| duration_seconds=round(duration, 4), | |
| error_type=type(e).__name__, | |
| error_message=str(e) | |
| ) | |
| raise | |
| def _should_track_performance(self, path: str) -> bool: | |
| """Determine if we should track performance for this endpoint.""" | |
| # Track performance for key endpoints | |
| tracked_endpoints = [ | |
| "/api/recommendations", | |
| "/api/chat", | |
| "/api/search" | |
| ] | |
| return any(tracked in path for tracked in tracked_endpoints) |