Spaces:
Sleeping
Sleeping
import html
import json
import logging
import os
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Union, Any

import httpx
from dify_client_python.dify_client import models
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from json_parser import SSEParser
from logger_config import setup_logger
from response_formatter import ResponseFormatter
# Populate os.environ from a local .env file, if present, before any
# os.getenv(...) lookups below (DIFY_API_KEY, API_BASE_URL, PORT).
load_dotenv()

# Root logging: INFO and above, formatted as "time - logger - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class AgentOutput(BaseModel):
    """Structured output from agent processing.

    NOTE(review): this model is not referenced anywhere else in the visible
    source; the field descriptions below are inferred from the field names
    and should be confirmed against the code that constructs it.
    """
    thought_content: str  # presumably the agent's reasoning ("thought") text
    observation: Optional[str]  # presumably the tool observation; None is an accepted value
    tool_outputs: List[Dict]  # presumably one dict per tool invocation — confirm schema
    citations: List[Dict]  # presumably retriever resources backing the answer — confirm schema
    metadata: Dict  # free-form metadata; schema not visible here
    raw_response: str  # presumably the unprocessed upstream payload
class AgentRequest(BaseModel):
    """Chat request accepted by the agent endpoint.

    The fields are forwarded to the upstream ``/chat-messages`` call built by
    ``AgentProcessor.process_stream``.
    """
    query: str  # user message, forwarded as "query"
    conversation_id: Optional[str] = None  # forwarded unchanged as "conversation_id"
    # True -> upstream response_mode "streaming", False -> "blocking".
    stream: bool = True
    # default_factory gives each request its own fresh container instead of
    # declaring a shared mutable literal as the default.
    inputs: Dict = Field(default_factory=dict)
    files: List = Field(default_factory=list)
    user: str = "default_user"
    # NOTE(review): process_stream derives the upstream response_mode from
    # `stream` and never reads this field; kept for API compatibility.
    response_mode: str = "streaming"
class AgentProcessor:
    """Relay chat requests to a Dify-style ``/chat-messages`` endpoint.

    ``process_stream`` POSTs the request upstream and reads the
    server-sent-event (SSE) reply line by line. For each event it logs a
    terminal-friendly rendering and re-emits the ``ResponseFormatter`` XML to
    the caller as SSE ``data`` events.
    """

    def __init__(self, api_key: str):
        """Set up configuration, the formatter, the HTTP client and the logger.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>`` on every
                upstream request.
        """
        self.api_key = api_key
        # Upstream base URL: API_BASE_URL when set, otherwise the default below.
        self.api_base = os.getenv(
            "API_BASE_URL",
            "https://rag-engine.go-yamamoto.com/v1"
        )
        self.formatter = ResponseFormatter()
        # One client (60 s timeout) shared by all requests; closed by cleanup().
        self.client = httpx.AsyncClient(timeout=60.0)
        self.logger = setup_logger("agent_processor")

    @staticmethod
    def _sse_data(payload: str) -> str:
        """Frame *payload* as a single SSE ``data`` event.

        In the SSE format a line break ends a field. Every line of a
        multi-line payload therefore gets its own ``data: `` prefix, and the
        client rejoins the lines with newlines. A single-line payload produces
        exactly ``"data: <payload>"`` followed by a blank line.
        """
        normalized = payload.replace("\r\n", "\n").replace("\r", "\n")
        return "".join(f"data: {line}\n" for line in normalized.split("\n")) + "\n"

    @classmethod
    def _error_event(cls, message: str) -> str:
        """Build an SSE event carrying ``<agent_response><error>...</error></agent_response>``.

        *message* is XML-escaped. Exception text containing ``<``, ``>`` or
        ``&`` therefore cannot corrupt the XML the client receives.
        """
        return cls._sse_data(
            "<agent_response>"
            f"<error>{html.escape(message, quote=False)}</error>"
            "</agent_response>"
        )

    async def log_request_details(
        self,
        request: "AgentRequest",
        start_time: datetime
    ) -> None:
        """Log the incoming request's fields at DEBUG level."""
        self.logger.debug(
            "Request details: \n"
            f"Query: {request.query}\n"
            f"User: {request.user}\n"
            f"Conversation ID: {request.conversation_id}\n"
            f"Stream mode: {request.stream}\n"
            f"Start time: {start_time}\n"
            f"Inputs: {request.inputs}\n"
            f"Files: {len(request.files)} files attached"
        )

    async def log_error(
        self,
        error: Exception,
        context: Optional[Dict] = None
    ) -> None:
        """Log *error* with its type, message, stack trace and optional context.

        Call this from inside an ``except`` block: ``traceback.format_exc()``
        formats the exception currently being handled.

        Args:
            error: The exception to report.
            context: Extra key/value details, appended as indented JSON.
        """
        error_msg = (
            f"Error type: {type(error).__name__}\n"
            f"Error message: {str(error)}\n"
            f"Stack trace:\n{traceback.format_exc()}\n"
        )
        if context:
            # default=str: this helper runs inside exception handlers and must
            # not raise on a non-JSON-serializable context value.
            error_msg += f"Context:\n{json.dumps(context, indent=2, default=str)}"
        self.logger.error(error_msg)

    async def cleanup(self):
        """Close the shared HTTP client."""
        await self.client.aclose()

    async def process_stream(self, request: "AgentRequest"):
        """Forward *request* upstream and return a streaming SSE response.

        Returns:
            A ``StreamingResponse`` (``text/event-stream``). Its events carry
            the XML ``content`` produced by ``clean_response``. Upstream HTTP
            errors, connection errors and parse errors are reported in-band as
            ``<agent_response><error>...</error></agent_response>`` events
            rather than raised, because the generator only runs after this
            method has returned the response.
        """
        start_time = datetime.now()
        await self.log_request_details(request, start_time)
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream"
        }
        chat_request = {
            "query": request.query,
            "inputs": request.inputs,
            "response_mode": "streaming" if request.stream else "blocking",
            "user": request.user,
            "conversation_id": request.conversation_id,
            "files": request.files
        }

        async def event_generator():
            parser = SSEParser()
            citations = []
            metadata = {}
            try:
                async with self.client.stream(
                    "POST",
                    f"{self.api_base}/chat-messages",
                    headers=headers,
                    json=chat_request
                ) as response:
                    self.logger.debug(
                        f"Stream connection established\n"
                        f"Status: {response.status_code}\n"
                        f"Headers: {dict(response.headers)}"
                    )
                    # A non-2xx reply carries a plain error body, not an SSE
                    # stream: report it instead of feeding it to the parser.
                    if response.status_code >= 400:
                        body = (await response.aread()).decode("utf-8", errors="replace")
                        self.logger.error(
                            f"Upstream returned HTTP {response.status_code}: {body}"
                        )
                        yield self._error_event(
                            f"Upstream error {response.status_code}: {body}"
                        )
                        return
                    buffer = ""
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        self.logger.debug(f"Raw SSE line: {line}")
                        # Terminal logging: decode the JSON payload of data lines.
                        if line.startswith("data:"):
                            try:
                                parsed = json.loads(line[len("data:"):].strip())
                                if parsed.get("event") == "message_end":
                                    citations = parsed.get("retriever_resources", [])
                                    metadata = parsed.get("metadata", {})
                                    self.logger.debug(
                                        f"Message end event:\n"
                                        f"Citations: {citations}\n"
                                        f"Metadata: {metadata}"
                                    )
                                formatted = self.format_terminal_output(
                                    parsed,
                                    citations=citations,
                                    metadata=metadata
                                )
                                if formatted:
                                    self.logger.info(formatted)
                            except Exception as e:
                                await self.log_error(
                                    e,
                                    {"line": line, "event": "parse_data"}
                                )
                        # Client output: non-data lines (e.g. "event: ...")
                        # accumulate until the next data line or a line ending
                        # in "}", then the buffer is parsed and reset.
                        buffer += line + "\n"
                        if line.startswith("data:") or buffer.strip().endswith("}"):
                            try:
                                processed_response = parser.parse_sse_event(buffer)
                                if processed_response and isinstance(processed_response, dict):
                                    cleaned_response = self.clean_response(processed_response)
                                    if cleaned_response:
                                        xml_content = str(cleaned_response.get("content", ""))
                                        yield self._sse_data(xml_content)
                            except Exception as parse_error:
                                await self.log_error(
                                    parse_error,
                                    {"buffer": buffer, "event": "process_buffer"}
                                )
                                yield self._error_event(str(parse_error))
                            finally:
                                buffer = ""
            except httpx.ConnectError as e:
                await self.log_error(e, {"event": "connection_error"})
                yield self._error_event(f"Connection error: {e}")
            except Exception as e:
                await self.log_error(e, {"event": "stream_error"})
                yield self._error_event(f"Streaming error: {e}")
            finally:
                end_time = datetime.now()
                duration = (end_time - start_time).total_seconds()
                self.logger.info(f"Request completed in {duration:.2f} seconds")

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
                "Access-Control-Allow-Origin": "*"
            }
        )

    def format_terminal_output(
        self,
        response: Dict,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None
    ) -> Optional[str]:
        """Render an upstream event as text for the terminal log.

        ``agent_thought``, ``agent_message`` and ``error`` events go through
        the matching ``ResponseFormatter`` method, and the first element of its
        result is returned. Any other event type returns ``None``.

        Args:
            response: Decoded upstream event (a dict with an ``"event"`` key).
            citations: Passed through to ``format_thought``.
            metadata: Passed through to ``format_thought``.
        """
        event_type = response.get("event")
        if event_type == "agent_thought":
            thought = response.get("thought", "")
            observation = response.get("observation", "")
            terminal_output, _ = self.formatter.format_thought(
                thought,
                observation,
                citations=citations,
                metadata=metadata
            )
            return terminal_output
        elif event_type == "agent_message":
            message = response.get("answer", "")
            terminal_output, _ = self.formatter.format_message(message)
            return terminal_output
        elif event_type == "error":
            error = response.get("error", "Unknown error")
            terminal_output, _ = self.formatter.format_error(error)
            return terminal_output
        return None

    def clean_response(self, response: Dict) -> Optional[Dict]:
        """Transform a parsed event into ``{"type": ..., "content": <xml>}`` for the client.

        ``agent_thought``, ``agent_message`` and ``error`` events map to types
        ``"thought"``, ``"message"`` and ``"error"``. The content is the second
        element of the matching ``ResponseFormatter`` result. Any other event,
        a missing ``"event"`` key, or an exception (which is logged) returns
        ``None``.
        """
        try:
            event_type = response.get("event")
            if not event_type:
                return None
            if event_type == "agent_thought":
                thought = response.get("thought", "")
                observation = response.get("observation", "")
                _, xml_output = self.formatter.format_thought(thought, observation)
                return {
                    "type": "thought",
                    "content": xml_output
                }
            elif event_type == "agent_message":
                message = response.get("answer", "")
                _, xml_output = self.formatter.format_message(message)
                return {
                    "type": "message",
                    "content": xml_output
                }
            elif event_type == "error":
                error = response.get("error", "Unknown error")
                _, xml_output = self.formatter.format_error(error)
                return {
                    "type": "error",
                    "content": xml_output
                }
            return None
        except Exception as e:
            # Use the instance logger, like every other method of this class.
            self.logger.error(f"Error cleaning response: {str(e)}")
            return None
# The FastAPI application. Swagger UI is served at the root path and ReDoc
# is turned off.
app = FastAPI(docs_url="/", redoc_url=None)

# Shared processor instance, assigned by startup_event(); None until then.
agent_processor: Optional["AgentProcessor"] = None

# CORS: any origin, any method and any header; credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True may let
# any website issue credentialed requests; narrow allow_origins if cookies or
# auth headers are ever relied upon.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
@app.on_event("startup")
async def startup_event():
    """Create the shared AgentProcessor when the application starts.

    The upstream API key is read from the DIFY_API_KEY environment variable.

    Raises:
        RuntimeError: if DIFY_API_KEY is unset or empty.
    """
    global agent_processor
    # Never fall back to a key embedded in source code: anyone who can read
    # the repository could use it. Fail loudly on misconfiguration instead.
    api_key = os.getenv("DIFY_API_KEY")
    if not api_key:
        raise RuntimeError(
            "DIFY_API_KEY environment variable is not set; "
            "cannot create the agent processor"
        )
    agent_processor = AgentProcessor(api_key=api_key)
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared AgentProcessor's HTTP client when the application stops."""
    global agent_processor
    if agent_processor is not None:
        await agent_processor.cleanup()
        # Drop the reference so nothing can reuse the closed client.
        agent_processor = None
# NOTE(review): no route decorator is visible for this handler. It must be
# registered (e.g. with @app.post(...)) to be reachable; confirm the path.
async def process_agent_request(request: AgentRequest):
    """Handle an agent chat request by streaming the upstream response.

    Returns:
        The ``StreamingResponse`` built by ``AgentProcessor.process_stream``.

    Raises:
        HTTPException: 503 if the shared agent processor has not been created
            yet; 500 for any other error raised while setting up the stream.
    """
    # Without this check an uninitialized processor surfaces as an opaque
    # AttributeError on None.
    if agent_processor is None:
        raise HTTPException(
            status_code=503,
            detail="Agent processor is not initialized"
        )
    try:
        logger.info(f"Processing agent request: {request.query}")
        return await agent_processor.process_stream(request)
    except Exception as e:
        logger.error(f"Error in agent request processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Convert any exception escaping the downstream app into a generic 500 response.

    The full error is logged server-side. The client only receives
    ``{"error": "Internal server error occurred"}``.
    """
    try:
        return await call_next(request)
    except Exception as e:
        logger.error(f"Unhandled error: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error occurred"}
        )
# Script entry point: serve the app with uvicorn on all interfaces, using the
# port from $PORT (default 7860).
# NOTE(review): the "api:app" import string assumes this module is api.py, and
# reload=True enables uvicorn's development auto-reloader.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        reload=True,
    )