Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Request, Security, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import Dict, List, Optional, Union, Any | |
| from pydantic import BaseModel, Field | |
| from datetime import datetime | |
| import logging | |
| import json | |
| import os | |
| from dotenv import load_dotenv | |
| from dify_client_python.dify_client import models | |
| from sse_starlette.sse import EventSourceResponse | |
| import httpx | |
| from json_parser import SSEParser | |
| from logger_config import setup_logger | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.responses import JSONResponse | |
| from response_formatter import ResponseFormatter | |
| import traceback | |
| from fastapi.security.api_key import APIKeyHeader, APIKey | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.openapi.docs import get_swagger_ui_html | |
| from fastapi.responses import HTMLResponse | |
# Read configuration (API keys, upstream URL, port) from a local .env file when one exists.
load_dotenv()

# Process-wide logging format; `logger` is this module's own logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Client authentication: callers must present CLIENT_API_KEY in the X-API-Key header.
API_KEY_NAME = "X-API-Key"
API_KEY = os.getenv("CLIENT_API_KEY")  # None when the variable is not set
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
class AgentOutput(BaseModel):
    """Structured result of one round of agent processing.

    NOTE(review): nothing visible alongside this model constructs it —
    confirm it is still needed.
    """

    # Presumably the agent's reasoning ("thought") text — confirm with producer.
    thought_content: str
    # Observation text attached to the thought, or None.
    observation: Union[str, None]
    # One dict per tool invocation; schema not established here.
    tool_outputs: List[Dict]
    # Citation records; schema not established here.
    citations: List[Dict]
    # Free-form metadata.
    metadata: Dict
    # Unprocessed upstream response text.
    raw_response: str
class AgentRequest(BaseModel):
    """Request body accepted by the agent endpoint.

    Every field except ``response_mode`` is forwarded to the upstream
    ``/chat-messages`` call by ``AgentProcessor.process_stream``.
    """

    # User prompt sent upstream as "query".
    query: str
    # Continues an existing upstream conversation when set.
    conversation_id: Optional[str] = None
    # True -> upstream response_mode "streaming", False -> "blocking".
    stream: bool = True
    # default_factory builds a fresh container per request instead of declaring
    # a shared mutable literal as the default.
    inputs: Dict = Field(default_factory=dict)
    files: List = Field(default_factory=list)
    user: str = "default_user"
    # NOTE(review): not read by process_stream, which derives the upstream
    # response_mode from `stream`; kept for request-schema compatibility.
    response_mode: str = "streaming"
class AgentProcessor:
    """Relay chat requests to the upstream RAG engine and re-stream its events.

    ``process_stream`` POSTs to ``{api_base}/chat-messages``, reads the upstream
    stream line by line, logs a terminal rendering of each event, and forwards
    the XML produced by ``ResponseFormatter`` to the client as SSE frames.
    """

    def __init__(self, api_key: str):
        """Store the upstream credential and create the shared HTTP client.

        Args:
            api_key: Sent upstream as ``Authorization: Bearer <api_key>``.
        """
        self.api_key = api_key
        # Upstream base URL, overridable through the API_BASE_URL variable.
        self.api_base = os.getenv(
            "API_BASE_URL",
            "https://rag-engine.go-yamamoto.com/v1"
        )
        self.formatter = ResponseFormatter()
        # Reused for every request; released by cleanup().
        self.client = httpx.AsyncClient(timeout=60.0)
        self.logger = setup_logger("agent_processor")

    @staticmethod
    def _sse_data(payload: str) -> str:
        """Frame *payload* as one SSE event.

        SSE terminates a field at every line break, so each line of a
        multi-line payload needs its own ``data:`` prefix (spec-compliant
        clients ignore unprefixed continuation lines). A single-line payload
        yields exactly ``data: <payload>`` followed by a blank line.
        """
        lines = payload.splitlines() or [""]
        return "".join(f"data: {line}\n" for line in lines) + "\n"

    @staticmethod
    def _error_xml(message: str) -> str:
        """Wrap *message* in the ``<agent_response><error>`` envelope.

        The message is XML-escaped so exception text containing ``<``, ``>`` or
        ``&`` cannot corrupt the envelope.
        """
        from xml.sax.saxutils import escape

        return f"<agent_response><error>{escape(message)}</error></agent_response>"

    async def log_request_details(
        self,
        request: AgentRequest,
        start_time: datetime
    ) -> None:
        """Log the incoming request's fields at DEBUG level."""
        self.logger.debug(
            "Request details: \n"
            f"Query: {request.query}\n"
            f"User: {request.user}\n"
            f"Conversation ID: {request.conversation_id}\n"
            f"Stream mode: {request.stream}\n"
            f"Start time: {start_time}\n"
            f"Inputs: {request.inputs}\n"
            f"Files: {len(request.files)} files attached"
        )

    async def log_error(
        self,
        error: Exception,
        context: Optional[Dict] = None
    ) -> None:
        """Log *error* with its type, message, traceback and optional context.

        The traceback is taken from the exception object itself, so this is
        correct even when called outside the ``except`` block that caught it
        (where ``traceback.format_exc()`` would report nothing).
        """
        stack = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        error_msg = (
            f"Error type: {type(error).__name__}\n"
            f"Error message: {error}\n"
            f"Stack trace:\n{stack}\n"
        )
        if context:
            # default=str so logging never fails on a non-JSON-serialisable value.
            error_msg += f"Context:\n{json.dumps(context, indent=2, default=str)}"
        self.logger.error(error_msg)

    async def cleanup(self):
        """Close the shared HTTP client."""
        await self.client.aclose()

    async def process_stream(self, request: AgentRequest):
        """Forward *request* upstream and return a StreamingResponse of SSE frames.

        Connection failures, non-2xx upstream responses and per-event parse
        failures are reported to the client as ``<agent_response><error>``
        frames inside the stream rather than raised.
        """
        start_time = datetime.now()
        await self.log_request_details(request, start_time)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream"
        }
        chat_request = {
            "query": request.query,
            "inputs": request.inputs,
            "response_mode": "streaming" if request.stream else "blocking",
            "user": request.user,
            "conversation_id": request.conversation_id,
            "files": request.files
        }

        async def event_generator():
            parser = SSEParser()
            # Updated from the upstream message_end event; used only for
            # terminal logging.
            citations = []
            metadata = {}
            try:
                async with self.client.stream(
                    "POST",
                    f"{self.api_base}/chat-messages",
                    headers=headers,
                    json=chat_request
                ) as response:
                    self.logger.debug(
                        f"Stream connection established\n"
                        f"Status: {response.status_code}\n"
                        f"Headers: {dict(response.headers)}"
                    )
                    if response.status_code >= 400:
                        # Error responses are a plain body, not an event stream;
                        # report them instead of feeding them to the SSE parser.
                        raw = await response.aread()
                        body = raw.decode("utf-8", errors="replace")
                        self.logger.error(
                            f"Upstream returned HTTP {response.status_code}: {body}"
                        )
                        yield self._sse_data(self._error_xml(
                            f"Upstream error {response.status_code}: {body}"
                        ))
                        return

                    buffer = ""
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        self.logger.debug(f"Raw SSE line: {line}")

                        # Terminal logging only; does not affect what is streamed.
                        if "data:" in line:
                            try:
                                data = line.split("data:", 1)[1].strip()
                                parsed = json.loads(data)
                                if parsed.get("event") == "message_end":
                                    citations = parsed.get("retriever_resources", [])
                                    metadata = parsed.get("metadata", {})
                                    self.logger.debug(
                                        f"Message end event:\n"
                                        f"Citations: {citations}\n"
                                        f"Metadata: {metadata}"
                                    )
                                formatted = self.format_terminal_output(
                                    parsed,
                                    citations=citations,
                                    metadata=metadata
                                )
                                if formatted:
                                    self.logger.info(formatted)
                            except Exception as e:
                                await self.log_error(
                                    e,
                                    {"line": line, "event": "parse_data"}
                                )

                        buffer += line + "\n"
                        if line.startswith("data:") or buffer.strip().endswith("}"):
                            try:
                                processed_response = parser.parse_sse_event(buffer)
                                if processed_response and isinstance(processed_response, dict):
                                    cleaned_response = self.clean_response(processed_response)
                                    if cleaned_response:
                                        xml_content = str(cleaned_response.get("content", ""))
                                        yield self._sse_data(xml_content)
                            except Exception as parse_error:
                                await self.log_error(
                                    parse_error,
                                    {"buffer": buffer, "event": "process_buffer"}
                                )
                                yield self._sse_data(self._error_xml(str(parse_error)))
                            finally:
                                buffer = ""
            except httpx.ConnectError as e:
                await self.log_error(e, {"event": "connection_error"})
                yield self._sse_data(self._error_xml(f"Connection error: {e}"))
            except Exception as e:
                await self.log_error(e, {"event": "stream_error"})
                yield self._sse_data(self._error_xml(f"Streaming error: {e}"))
            finally:
                duration = (datetime.now() - start_time).total_seconds()
                self.logger.info(f"Request completed in {duration:.2f} seconds")

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
                "Access-Control-Allow-Origin": "*"
            }
        )

    def format_terminal_output(
        self,
        response: Dict,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None
    ) -> Optional[str]:
        """Render an upstream event for terminal logging.

        Returns:
            The terminal half of the formatter's output for ``agent_thought``,
            ``agent_message`` and ``error`` events; None for any other event.
        """
        event_type = response.get("event")
        if event_type == "agent_thought":
            terminal_output, _ = self.formatter.format_thought(
                response.get("thought", ""),
                response.get("observation", ""),
                citations=citations,
                metadata=metadata
            )
            return terminal_output
        if event_type == "agent_message":
            terminal_output, _ = self.formatter.format_message(response.get("answer", ""))
            return terminal_output
        if event_type == "error":
            terminal_output, _ = self.formatter.format_error(
                response.get("error", "Unknown error")
            )
            return terminal_output
        return None

    def clean_response(self, response: Dict) -> Optional[Dict]:
        """Convert an upstream event into ``{"type": ..., "content": <xml>}``.

        Returns None for events without an ``event`` key, for unhandled event
        types, and when the formatter raises (the error is logged).
        """
        try:
            event_type = response.get("event")
            if not event_type:
                return None
            if event_type == "agent_thought":
                _, xml_output = self.formatter.format_thought(
                    response.get("thought", ""),
                    response.get("observation", "")
                )
                return {"type": "thought", "content": xml_output}
            if event_type == "agent_message":
                _, xml_output = self.formatter.format_message(response.get("answer", ""))
                return {"type": "message", "content": xml_output}
            if event_type == "error":
                _, xml_output = self.formatter.format_error(
                    response.get("error", "Unknown error")
                )
                return {"type": "error", "content": xml_output}
            return None
        except Exception as e:
            # Use the processor's own logger, consistent with the other methods.
            self.logger.error(f"Error cleaning response: {str(e)}")
            return None
# The FastAPI application. The built-in /docs page is switched off
# (docs_url=None) so a custom Swagger UI page can be served instead.
app = FastAPI(
    title="Agent API",
    version="1.0.0",
    description="API requiring X-API-Key authentication",
    openapi_tags=[{"name": "agent", "description": "Agent endpoints"}],
    docs_url=None,
)

# Shared AgentProcessor instance; assigned in startup_event().
agent_processor = None
# Allow cross-origin browser clients. Registered once; it was previously
# added twice, wrapping every request in two identical CORS layers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# OpenAPI security scheme describing the X-API-Key header.
security_scheme = {
    "ApiKeyAuth": {
        "type": "apiKey",
        "in": "header",
        "name": "X-API-Key",
        "description": "API key required for authentication"
    }
}


def custom_openapi():
    """Build (once) and return the OpenAPI schema with ApiKeyAuth applied globally.

    FastAPI has no ``add_security_requirement`` method (calling it raised
    AttributeError at import), and assigning ``app.openapi_components`` has no
    effect. The supported approach is to override ``app.openapi`` and edit the
    generated schema.
    """
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi

    schema = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
        tags=app.openapi_tags,
    )
    # Merge with any schemes FastAPI generated from route dependencies.
    schema.setdefault("components", {}).setdefault("securitySchemes", {}).update(
        security_scheme
    )
    # Global requirement: the Swagger "Authorize" dialog applies to every operation.
    schema["security"] = [{"ApiKeyAuth": []}]
    app.openapi_schema = schema
    return schema


app.openapi_schema = None  # force regeneration through custom_openapi
app.openapi = custom_openapi
@app.on_event("startup")
async def startup_event():
    """Create the shared AgentProcessor when the application starts.

    Raises:
        RuntimeError: DIFY_API_KEY is not set.
    """
    global agent_processor
    # No hard-coded fallback: a credential used to be embedded here. That key
    # should be treated as leaked and rotated.
    api_key = os.getenv("DIFY_API_KEY")
    if not api_key:
        raise RuntimeError("DIFY_API_KEY environment variable is not set")
    agent_processor = AgentProcessor(api_key=api_key)
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared AgentProcessor's HTTP client when the application stops."""
    global agent_processor
    if agent_processor is not None:
        await agent_processor.cleanup()
        # Drop the reference so a closed client cannot be reused.
        agent_processor = None
# Dependency that authenticates callers via the X-API-Key header.
async def get_api_key(
    api_key_header: str = Security(api_key_header)
) -> APIKey:
    """Validate the X-API-Key header value against CLIENT_API_KEY.

    Returns:
        The supplied key when it matches.

    Raises:
        HTTPException: 500 when CLIENT_API_KEY is not configured on the server;
            403 when the supplied key does not match.
    """
    import secrets

    if not API_KEY:
        raise HTTPException(
            status_code=500,
            detail="API key configuration is missing on server"
        )
    # compare_digest takes time independent of where the values first differ,
    # so the key cannot be guessed byte by byte from response timing. Bytes
    # are compared because str arguments must be ASCII-only.
    if secrets.compare_digest(api_key_header.encode("utf-8"), API_KEY.encode("utf-8")):
        return api_key_header
    raise HTTPException(
        status_code=403,
        detail="Invalid or missing API key"
    )
# NOTE(review): no route decorator is visible for this handler; confirm it is
# registered on `app` (e.g. with @app.post(...)).
async def process_agent_request(
    request: AgentRequest,
    api_key: APIKey = Security(get_api_key)
):
    """Run the agent for *request* and return its SSE stream.

    The ``api_key`` dependency is ``get_api_key``, which compares the header
    value with CLIENT_API_KEY. Depending on the bare header scheme only checked
    that some value was present, so any key was accepted.

    Raises:
        HTTPException: 503 if the processor has not been initialised; 500 for
            any other failure while starting the stream.
    """
    if agent_processor is None:
        raise HTTPException(status_code=503, detail="Agent processor is not initialized")
    try:
        logger.info(f"Processing agent request: {request.query}")
        return await agent_processor.process_stream(request)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in agent request processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Turn any exception escaping the handlers into a generic JSON 500.

    The exception and its traceback are logged server-side. The client only
    receives a fixed message, so internal details are not exposed.
    """
    try:
        return await call_next(request)
    except Exception as e:
        logger.exception(f"Unhandled error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error occurred"}
        )
# Development entry point: listen on all interfaces, port from $PORT (default 7860).
# reload=True needs the app given as an import string, hence "api:app".
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        reload=True,
    )
# Dark theme injected into the Swagger UI page served at /docs.
_SWAGGER_DARK_THEME_CSS = """
<style>
    /* Dark theme with cool colors */
    :root {
        --primary-color: #00b4d8;
        --secondary-color: #90e0ef;
        --background-color: #0d1117;
        --text-color: #e6edf3;
        --border-color: #30363d;
    }
    body {
        background-color: var(--background-color);
        color: var(--text-color);
    }
    .swagger-ui {
        background-color: var(--background-color);
        color: var(--text-color);
    }
    /* Headers and text */
    .swagger-ui .info .title,
    .swagger-ui .info .base-url,
    .swagger-ui .info li,
    .swagger-ui .info p,
    .swagger-ui .info table {
        color: var(--text-color);
    }
    /* Operation buttons */
    .swagger-ui .opblock.opblock-post {
        background: rgba(0, 180, 216, 0.1);
        border-color: var(--primary-color);
    }
    .swagger-ui .opblock.opblock-post .opblock-summary-method {
        background: var(--primary-color);
    }
    /* Authorize button */
    .swagger-ui .btn.authorize {
        background: var(--primary-color);
        border-color: var(--primary-color);
        color: white;
    }
    .swagger-ui .btn.authorize svg {
        fill: white;
    }
    /* Schema sections */
    .swagger-ui .model-box {
        background: rgba(48, 54, 61, 0.4);
    }
    .swagger-ui .model {
        color: var(--text-color);
    }
    /* Try it out section */
    .swagger-ui textarea,
    .swagger-ui input[type=text] {
        background: var(--background-color);
        color: var(--text-color);
        border-color: var(--border-color);
    }
    /* Response section */
    .swagger-ui .responses-table th,
    .swagger-ui .responses-table td {
        color: var(--text-color);
        border-color: var(--border-color);
    }
    /* Scrollbar */
    ::-webkit-scrollbar {
        width: 8px;
        height: 8px;
    }
    ::-webkit-scrollbar-track {
        background: var(--background-color);
    }
    ::-webkit-scrollbar-thumb {
        background: var(--primary-color);
        border-radius: 4px;
    }
    /* Code blocks */
    .swagger-ui .highlight-code {
        background-color: #1b1f24;
    }
    /* Modal dialogs */
    .swagger-ui .dialog-ux .modal-ux {
        background: var(--background-color);
        border-color: var(--border-color);
    }
    .swagger-ui .dialog-ux .modal-ux-header h3 {
        color: var(--text-color);
    }
    /* Tables */
    .swagger-ui table thead tr td,
    .swagger-ui table thead tr th {
        color: var(--text-color);
        border-color: var(--border-color);
    }
    /* Links */
    .swagger-ui a {
        color: var(--primary-color);
    }
    /* Schema dropdowns */
    .swagger-ui select {
        background: var(--background-color);
        color: var(--text-color);
        border-color: var(--border-color);
    }
</style>
"""


# Custom docs endpoint, replacing the default /docs disabled via docs_url=None.
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve the Swagger UI page with a dark theme.

    ``get_swagger_ui_html`` has no ``extra_html`` parameter (passing one raised
    TypeError on every request). Instead, the theme's ``<style>`` block is
    spliced into the generated page just before ``</head>``. Because it comes
    after the swagger-ui stylesheet link, its rules override the defaults. If
    the page ever lacks ``</head>``, it is served unthemed.
    """
    swagger_page = get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title + " - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_js_url="https://unpkg.com/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
        swagger_css_url="https://unpkg.com/swagger-ui-dist@5.9.0/swagger-ui.css",
        swagger_favicon_url="https://fastapi.tiangolo.com/img/favicon.png",
    )
    html = swagger_page.body.decode("utf-8")
    html = html.replace("</head>", _SWAGGER_DARK_THEME_CSS + "</head>", 1)
    return HTMLResponse(content=html)