Spaces:
Sleeping
Sleeping
| """FastAPI routes for the Ask-the-Web Agent.""" | |
| from __future__ import annotations | |
| import time | |
| import uuid | |
| from typing import AsyncGenerator | |
| from fastapi import APIRouter, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| import json | |
| from src.api.schemas import ( | |
| QueryRequest, | |
| QueryResponse, | |
| SourceInfo, | |
| StreamingChunk, | |
| ErrorResponse, | |
| HealthResponse, | |
| ) | |
| from src.utils.config import get_settings | |
| from src.utils.logging import get_logger | |
| from src.utils.exceptions import ( | |
| AskTheWebError, | |
| ConfigurationError, | |
| LLMError, | |
| ToolError, | |
| ) | |
# Router instance the application factory mounts; route handlers in this
# module are registered on it (decorators not visible in this chunk).
router = APIRouter()

# Module-level logger, per the project's logging convention.
logger = get_logger(__name__)

# In-memory conversation storage (use Redis/DB in production).
# NOTE(review): unbounded and process-local — entries are never evicted,
# and history is lost on restart / not shared across workers.
_conversations: dict[str, list[dict]] = {}
async def health_check() -> HealthResponse:
    """Report service health from the configured credentials.

    The service is "healthy" only when both an LLM provider key (OpenAI or
    Anthropic) and the Tavily search key are present; otherwise "degraded".
    """
    settings = get_settings()
    llm_ok = bool(settings.openai_api_key or settings.anthropic_api_key)
    search_ok = bool(settings.tavily_api_key)
    components = {"llm_configured": llm_ok, "search_configured": search_ok}
    overall = "healthy" if llm_ok and search_ok else "degraded"
    return HealthResponse(
        status=overall,
        version="1.0.0",
        components=components,
    )
async def query(request: QueryRequest) -> QueryResponse:
    """Process a user query and return an answer with sources.

    Accepts a natural language question, optionally continues an existing
    conversation, and returns a comprehensive answer produced via web
    search and AI reasoning.

    Args:
        request: The query payload (question, optional conversation ID,
            search toggle, and source limit).

    Returns:
        QueryResponse with the answer, cited sources, follow-up questions,
        confidence, conversation ID, and timing metadata.

    Raises:
        HTTPException: 503 when the LLM backend fails; 500 for
            configuration, tool, agent, or unexpected errors.
    """
    start_time = time.time()
    try:
        # Import agent here to avoid circular imports.
        from src.agent.agent import AskTheWebAgent

        agent = AskTheWebAgent()

        # Resume prior context when a known conversation ID is supplied;
        # .get avoids the membership-test-plus-lookup double access.
        history = _conversations.get(request.conversation_id or "", [])

        response = await agent.query(
            question=request.query,
            history=history,
            enable_search=request.enable_search,
            max_sources=request.max_sources,
        )

        # Reuse the caller's conversation ID or mint a fresh one.
        conversation_id = request.conversation_id or str(uuid.uuid4())

        # Append this exchange to the in-memory conversation store.
        _conversations.setdefault(conversation_id, []).extend([
            {"role": "user", "content": request.query},
            {"role": "assistant", "content": response.answer},
        ])

        processing_time_ms = int((time.time() - start_time) * 1000)

        # Convert agent-side source dicts into the API schema.
        sources = [
            SourceInfo(
                title=s.get("title", ""),
                url=s.get("url", ""),
                snippet=s.get("snippet", ""),
            )
            for s in response.sources
        ]

        return QueryResponse(
            answer=response.answer,
            sources=sources,
            follow_up_questions=response.follow_up_questions,
            confidence=response.confidence,
            conversation_id=conversation_id,
            processing_time_ms=processing_time_ms,
            metadata=response.metadata,
        )
    except ConfigurationError as e:
        logger.error(f"Configuration error: {e}")
        # Chain the cause so the original traceback survives into logs.
        raise HTTPException(
            status_code=500,
            detail={"error": str(e), "error_code": "CONFIGURATION_ERROR"},
        ) from e
    except LLMError as e:
        logger.error(f"LLM error: {e}")
        # 503: the LLM backend is unavailable, not a server-side bug.
        raise HTTPException(
            status_code=503,
            detail={"error": str(e), "error_code": "LLM_ERROR"},
        ) from e
    except ToolError as e:
        logger.error(f"Tool error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e), "error_code": "TOOL_ERROR"},
        ) from e
    except AskTheWebError as e:
        # Base-class handler: must come after the specific subtypes above.
        logger.error(f"Agent error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e), "error_code": "AGENT_ERROR"},
        ) from e
    except Exception as e:
        # Boundary catch-all: log the full traceback, hide details from
        # the client.
        logger.exception(f"Unexpected error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "An unexpected error occurred", "error_code": "INTERNAL_ERROR"},
        ) from e
async def query_stream(request: QueryRequest) -> StreamingResponse:
    """Process a query with a streaming (NDJSON) response.

    Emits one JSON object per line: a "start" chunk, "content" chunks as
    the answer is delivered, one "source" chunk per source, and a final
    "done" chunk with metadata (including the conversation ID). Errors are
    reported in-band as an "error" chunk because response headers have
    already been sent.
    """

    async def generate() -> AsyncGenerator[str, None]:
        start_time = time.time()
        try:
            # Import here to avoid circular imports (matches `query`).
            from src.agent.agent import AskTheWebAgent

            agent = AskTheWebAgent()

            # Announce the stream and echo the query back to the client.
            yield json.dumps(StreamingChunk(
                type="start",
                content="",
                metadata={"query": request.query},
            ).model_dump()) + "\n"

            # Resume prior context when a known conversation ID is supplied.
            history = _conversations.get(request.conversation_id or "", [])

            response = await agent.query(
                question=request.query,
                history=history,
                enable_search=request.enable_search,
                max_sources=request.max_sources,
            )

            # Persist the exchange so follow-up requests can continue it —
            # previously only the non-streaming endpoint stored history,
            # silently dropping streamed turns from conversation context.
            conversation_id = request.conversation_id or str(uuid.uuid4())
            _conversations.setdefault(conversation_id, []).extend([
                {"role": "user", "content": request.query},
                {"role": "assistant", "content": response.answer},
            ])

            # Stream answer in fixed-size word chunks (simulated streaming;
            # the agent currently returns the full answer at once).
            words = response.answer.split()
            chunk_size = 5
            for i in range(0, len(words), chunk_size):
                yield json.dumps(StreamingChunk(
                    type="content",
                    content=" ".join(words[i:i + chunk_size]) + " ",
                ).model_dump()) + "\n"

            # Send each source as its own chunk.
            for source in response.sources:
                yield json.dumps(StreamingChunk(
                    type="source",
                    source=SourceInfo(
                        title=source.get("title", ""),
                        url=source.get("url", ""),
                        snippet=source.get("snippet", ""),
                    ),
                ).model_dump()) + "\n"

            # Final chunk: confidence, follow-ups, timing, conversation ID.
            processing_time_ms = int((time.time() - start_time) * 1000)
            yield json.dumps(StreamingChunk(
                type="done",
                metadata={
                    "confidence": response.confidence,
                    "follow_up_questions": response.follow_up_questions,
                    "processing_time_ms": processing_time_ms,
                    "conversation_id": conversation_id,
                },
            ).model_dump()) + "\n"
        except Exception as e:
            # Headers are already on the wire — surface the error in-band.
            logger.exception(f"Streaming error: {e}")
            yield json.dumps({
                "type": "error",
                "error": str(e),
            }) + "\n"

    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
    )
async def delete_conversation(conversation_id: str) -> dict:
    """Delete the stored history for *conversation_id*.

    Raises:
        HTTPException: 404 when the conversation is unknown.
    """
    # Guard clause: reject unknown IDs before touching the store.
    if conversation_id not in _conversations:
        raise HTTPException(
            status_code=404,
            detail={"error": "Conversation not found", "error_code": "NOT_FOUND"},
        )
    del _conversations[conversation_id]
    return {"message": "Conversation deleted", "conversation_id": conversation_id}
async def get_conversation(conversation_id: str) -> dict:
    """Return the stored message history for *conversation_id*.

    Raises:
        HTTPException: 404 when the conversation is unknown.
    """
    # Single lookup instead of membership test + indexing.
    messages = _conversations.get(conversation_id)
    if messages is None:
        raise HTTPException(
            status_code=404,
            detail={"error": "Conversation not found", "error_code": "NOT_FOUND"},
        )
    return {
        "conversation_id": conversation_id,
        "messages": messages,
    }