ask-the-web-agent / src /api /routes.py
debashis2007's picture
Upload folder using huggingface_hub
75bea1c verified
"""FastAPI routes for the Ask-the-Web Agent."""
from __future__ import annotations
import time
import uuid
from typing import AsyncGenerator
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
import json
from src.api.schemas import (
QueryRequest,
QueryResponse,
SourceInfo,
StreamingChunk,
ErrorResponse,
HealthResponse,
)
from src.utils.config import get_settings
from src.utils.logging import get_logger
from src.utils.exceptions import (
AskTheWebError,
ConfigurationError,
LLMError,
ToolError,
)
logger = get_logger(__name__)
router = APIRouter()

# Process-local, in-memory store of chat turns keyed by conversation id. Each
# value is a list of {"role": ..., "content": ...} message dicts. Being a plain
# module-level dict, it is lost on restart and not shared between worker
# processes; use Redis or a database in production.
_conversations: dict[str, list[dict]] = {}
@router.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check() -> HealthResponse:
    """Report whether the service has the credentials it needs.

    The LLM component counts as configured when either an OpenAI or an
    Anthropic API key is set; search counts as configured when a Tavily key
    is set. Overall status is "healthy" only when both are configured,
    otherwise "degraded".
    """
    cfg = get_settings()
    llm_ready = bool(cfg.openai_api_key or cfg.anthropic_api_key)
    search_ready = bool(cfg.tavily_api_key)

    overall = "healthy" if (llm_ready and search_ready) else "degraded"
    return HealthResponse(
        status=overall,
        version="1.0.0",
        components={
            "llm_configured": llm_ready,
            "search_configured": search_ready,
        },
    )
@router.post("/query", response_model=QueryResponse, tags=["Query"])
async def query(request: QueryRequest) -> QueryResponse:
    """Process a user query and return an answer with sources.

    This endpoint accepts a natural language question and returns a
    comprehensive answer using web search and AI reasoning.

    If ``request.conversation_id`` names a stored conversation, its earlier
    turns are passed to the agent as history. The new user/assistant turn is
    then appended to that conversation (created if needed, with a fresh UUID
    when no id was supplied), and the id is returned in the response.

    Raises:
        HTTPException: status 503 with error code ``LLM_ERROR`` when the LLM
            fails; status 500 with ``CONFIGURATION_ERROR``, ``TOOL_ERROR``,
            ``AGENT_ERROR`` or ``INTERNAL_ERROR`` otherwise. ``detail`` is a
            dict with ``error`` and ``error_code`` keys; unexpected errors get
            a generic message so internals are not exposed.
    """
    # perf_counter is monotonic: the measured duration cannot go negative or
    # jump when the system clock is adjusted, unlike time.time().
    start_time = time.perf_counter()
    try:
        # Import agent here to avoid circular imports
        from src.agent.agent import AskTheWebAgent

        agent = AskTheWebAgent()

        # Hand the agent a copy so the stored turns cannot be altered (or
        # later duplicated) if the agent mutates the list it receives.
        history: list[dict] = []
        if request.conversation_id and request.conversation_id in _conversations:
            history = list(_conversations[request.conversation_id])

        response = await agent.query(
            question=request.query,
            history=history,
            enable_search=request.enable_search,
            max_sources=request.max_sources,
        )

        # Reuse the caller's conversation id, or start a new conversation.
        conversation_id = request.conversation_id or str(uuid.uuid4())
        _conversations.setdefault(conversation_id, []).extend([
            {"role": "user", "content": request.query},
            {"role": "assistant", "content": response.answer},
        ])

        processing_time_ms = int((time.perf_counter() - start_time) * 1000)

        # Convert the agent's source dicts to the API schema; missing fields
        # default to empty strings.
        sources = [
            SourceInfo(
                title=s.get("title", ""),
                url=s.get("url", ""),
                snippet=s.get("snippet", ""),
            )
            for s in response.sources
        ]

        return QueryResponse(
            answer=response.answer,
            sources=sources,
            follow_up_questions=response.follow_up_questions,
            confidence=response.confidence,
            conversation_id=conversation_id,
            processing_time_ms=processing_time_ms,
            metadata=response.metadata,
        )
    # Specific agent errors are handled before the AskTheWebError catch-all.
    except ConfigurationError as e:
        logger.error(f"Configuration error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e), "error_code": "CONFIGURATION_ERROR"},
        ) from e
    except LLMError as e:
        logger.error(f"LLM error: {e}")
        raise HTTPException(
            status_code=503,
            detail={"error": str(e), "error_code": "LLM_ERROR"},
        ) from e
    except ToolError as e:
        logger.error(f"Tool error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e), "error_code": "TOOL_ERROR"},
        ) from e
    except AskTheWebError as e:
        logger.error(f"Agent error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e), "error_code": "AGENT_ERROR"},
        ) from e
    except Exception as e:
        # Log the full traceback but return a generic message to the client.
        logger.exception(f"Unexpected error: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "An unexpected error occurred", "error_code": "INTERNAL_ERROR"},
        ) from e
@router.post("/query/stream", tags=["Query"])
async def query_stream(request: QueryRequest) -> StreamingResponse:
    """Process a query with streaming response.

    Returns a stream of newline-delimited JSON chunks, in order:
    - one ``start`` chunk;
    - ``content`` chunks whose concatenation is the answer text;
    - one ``source`` chunk per source;
    - a final ``done`` chunk.

    If anything fails, an ``error`` object (keys ``type``, ``error``,
    ``error_code``) is emitted and the stream ends.

    The agent currently produces the whole answer before any content is sent;
    the answer is then split into groups of words (simulated streaming). As in
    ``/query``, the new turn is appended to the conversation history, and the
    conversation id is reported in the ``done`` chunk's metadata.
    """
    import re  # local import, matching this module's function-scope imports

    # Words per simulated content chunk.
    chunk_size = 5
    # Exception type -> error code, checked in order (most specific first,
    # mirroring the except-clause order and codes used by /query).
    error_codes = (
        (ConfigurationError, "CONFIGURATION_ERROR"),
        (LLMError, "LLM_ERROR"),
        (ToolError, "TOOL_ERROR"),
        (AskTheWebError, "AGENT_ERROR"),
    )

    async def generate() -> AsyncGenerator[str, None]:
        # Monotonic clock so the reported duration is immune to clock changes.
        start_time = time.perf_counter()
        try:
            from src.agent.agent import AskTheWebAgent

            agent = AskTheWebAgent()

            # Send start chunk
            yield json.dumps(StreamingChunk(
                type="start",
                content="",
                metadata={"query": request.query},
            ).model_dump()) + "\n"

            # Copy the stored history so the agent cannot alter stored turns.
            history: list[dict] = []
            if request.conversation_id and request.conversation_id in _conversations:
                history = list(_conversations[request.conversation_id])

            response = await agent.query(
                question=request.query,
                history=history,
                enable_search=request.enable_search,
                max_sources=request.max_sources,
            )

            # Persist the turn like /query does, so a follow-up request that
            # reuses the conversation id sees this exchange.
            conversation_id = request.conversation_id or str(uuid.uuid4())
            _conversations.setdefault(conversation_id, []).extend([
                {"role": "user", "content": request.query},
                {"role": "assistant", "content": response.answer},
            ])

            # Each token is one word together with its surrounding whitespace.
            # The chunks therefore concatenate back to the exact answer, keeping
            # newlines and markdown layout. Plain split()/join would collapse
            # them and add a trailing space.
            tokens = re.findall(r"\s*\S+\s*", response.answer)
            for i in range(0, len(tokens), chunk_size):
                yield json.dumps(StreamingChunk(
                    type="content",
                    content="".join(tokens[i:i + chunk_size]),
                ).model_dump()) + "\n"

            # One chunk per source; missing fields default to empty strings.
            for source in response.sources:
                yield json.dumps(StreamingChunk(
                    type="source",
                    source=SourceInfo(
                        title=source.get("title", ""),
                        url=source.get("url", ""),
                        snippet=source.get("snippet", ""),
                    ),
                ).model_dump()) + "\n"

            processing_time_ms = int((time.perf_counter() - start_time) * 1000)
            yield json.dumps(StreamingChunk(
                type="done",
                metadata={
                    "confidence": response.confidence,
                    "follow_up_questions": response.follow_up_questions,
                    "processing_time_ms": processing_time_ms,
                    "conversation_id": conversation_id,
                },
            ).model_dump()) + "\n"
        except Exception as e:
            # Headers are already sent, so report the failure in-band.
            logger.exception(f"Streaming error: {e}")
            error_code = next(
                (code for exc_type, code in error_codes if isinstance(e, exc_type)),
                "INTERNAL_ERROR",
            )
            # Like /query: known agent errors expose their message; anything
            # unexpected gets a generic message so internals do not leak.
            if error_code == "INTERNAL_ERROR":
                message = "An unexpected error occurred"
            else:
                message = str(e)
            yield json.dumps({
                "type": "error",
                "error": message,
                "error_code": error_code,
            }) + "\n"

    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
    )
@router.delete("/conversation/{conversation_id}", tags=["Conversation"])
async def delete_conversation(conversation_id: str) -> dict:
    """Remove a stored conversation and confirm the deletion.

    Raises:
        HTTPException: 404 with error code ``NOT_FOUND`` when no conversation
            with this id is stored.
    """
    if conversation_id not in _conversations:
        raise HTTPException(
            status_code=404,
            detail={"error": "Conversation not found", "error_code": "NOT_FOUND"},
        )

    del _conversations[conversation_id]
    return {"message": "Conversation deleted", "conversation_id": conversation_id}
@router.get("/conversation/{conversation_id}", tags=["Conversation"])
async def get_conversation(conversation_id: str) -> dict:
    """Return the stored message list of one conversation.

    Raises:
        HTTPException: 404 with error code ``NOT_FOUND`` when no conversation
            with this id is stored.
    """
    # Stored values are always lists, so None reliably means "unknown id".
    messages = _conversations.get(conversation_id)
    if messages is None:
        raise HTTPException(
            status_code=404,
            detail={"error": "Conversation not found", "error_code": "NOT_FOUND"},
        )
    return {"conversation_id": conversation_id, "messages": messages}