# Source: debashis2007 — "Fix: Add query() method to AskTheWebAgent" (commit a2978eb, verified)
"""Main agent class that ties everything together."""

# NOTE: the module docstring must come BEFORE the ``from __future__`` import.
# A future statement may only be preceded by the docstring, comments and blank
# lines; a string literal placed after it is a plain expression and the module's
# ``__doc__`` would be None.
from __future__ import annotations

import json
from typing import Any, AsyncIterator

from src.agent.controller import AgentController
from src.agent.models import AgentResponse, AgentState, Citation
from src.llm import LLMClient, Message, MessageRole, create_llm_client
from src.llm.prompts import PromptNames, format_prompt, get_system_prompt
from src.tools.base import ToolRegistry, create_default_registry
from src.utils.config import settings
from src.utils.exceptions import AgentError
from src.utils.logging import get_logger

# Module-level logger named after this module, per project convention.
logger = get_logger(__name__)
class AskTheWebAgent:
    """Main Ask-the-Web Agent class."""

    def __init__(
        self,
        llm_client: LLMClient | None = None,
        tool_registry: ToolRegistry | None = None,
    ):
        """Set up the agent's collaborators.

        Args:
            llm_client: LLM client to use; a default is created when omitted.
            tool_registry: Tool registry to use; a default is created when omitted.
        """
        # Fall back to factory-built defaults for any collaborator not supplied.
        self.llm = llm_client if llm_client else create_llm_client()
        self.tools = tool_registry if tool_registry else create_default_registry()
        # The controller drives intent parsing, planning and the ReACT loop.
        self.controller = AgentController(self.llm, self.tools)
        # System prompt shared by every LLM conversation the agent starts.
        self.system_prompt = get_system_prompt()
async def ask(
    self,
    query: str,
    context: dict[str, Any] | None = None,
    use_react: bool = True,
) -> AgentResponse:
    """Ask the agent a question.

    Args:
        query: User's question
        context: Optional conversation context, merged into working memory
        use_react: Whether to use ReACT reasoning (default True)

    Returns:
        AgentResponse with answer and sources

    Raises:
        AgentError: wraps any failure raised while processing the query.
    """
    logger.info(f"Processing query: {query}")
    # Initialize per-request state
    state = AgentState(query=query)
    if context:
        state.working_memory.update(context)
    try:
        # 1. Parse intent
        state.intent = await self.controller.parse_intent(query)
        logger.info(f"Parsed intent: {state.intent.intent_type.value}")
        # 2. Plan workflow
        state.plan = await self.controller.plan_workflow(query, state.intent)
        logger.info(f"Selected strategy: {state.plan.strategy.value}")
        # 3. Execute based on strategy
        if use_react and state.intent.requires_web_search:
            # Use ReACT loop for complex queries that need the web
            answer, thought_history = await self.controller.run_react_loop(state)
            state.thought_history = thought_history
            # Extract sources from tool results recorded in the history
            sources = self._extract_sources(state)
            # Generate follow-up questions
            follow_ups = await self._generate_follow_ups(query, answer)
        else:
            # Simple execution for direct answers — no sources/follow-ups
            answer = await self._direct_answer(query, state)
            sources = []
            follow_ups = []
        # 4. Reflect on response quality (optional)
        confidence = 0.8  # default confidence when reflection is disabled
        if settings.enable_reflection:
            confidence = await self._reflect_on_response(query, answer, sources)
        return AgentResponse(
            answer=answer,
            confidence=confidence,
            sources=sources,
            reasoning_summary=self._summarize_reasoning(state),
            follow_up_questions=follow_ups,
            metadata={
                "intent": state.intent.intent_type.value if state.intent else None,
                "strategy": state.plan.strategy.value if state.plan else None,
                "iterations": len(state.thought_history),
            },
        )
    except Exception as e:
        # Normalize any failure to the package's AgentError, preserving the cause.
        logger.error(f"Agent error: {e}")
        raise AgentError(f"Failed to process query: {e}") from e
async def ask_stream(
    self,
    query: str,
    context: dict[str, Any] | None = None,
) -> AsyncIterator[str]:
    """Stream a response to a question.

    A simplified pipeline compared to ``ask()``: parse the intent, run one
    web search when the intent calls for it, then stream the LLM answer.

    Args:
        query: User's question
        context: Optional conversation context

    Yields:
        Response chunks as they become available
    """
    logger.info(f"Streaming query: {query}")
    state = AgentState(query=query)
    # Intent decides whether a web search should seed the answer.
    state.intent = await self.controller.parse_intent(query)
    # A failed or skipped search simply leaves the context empty (best effort).
    context_text = ""
    if state.intent.requires_web_search:
        search = await self.tools.execute(
            "web_search", query=query, num_results=settings.max_search_results
        )
        if search.success:
            context_text = self._format_search_results(
                search.data.get("results", [])
            )
    # Stream the final answer straight from the LLM.
    prompt_messages = [
        Message(role=MessageRole.SYSTEM, content=self.system_prompt),
        Message(
            role=MessageRole.USER,
            content=f"Question: {query}\n\nContext from web search:\n{context_text}\n\nProvide a comprehensive answer with citations.",
        ),
    ]
    async for chunk in self.llm.chat_stream(prompt_messages):
        yield chunk
async def _direct_answer(self, query: str, state: AgentState) -> str:
    """Answer the question from the LLM alone, without any web search.

    Args:
        query: User's question
        state: Current state (unused here; kept for a uniform call shape)

    Returns:
        Answer string
    """
    response = await self.llm.chat(
        [
            Message(role=MessageRole.SYSTEM, content=self.system_prompt),
            Message(role=MessageRole.USER, content=query),
        ],
        temperature=0.7,
    )
    # Guard against an empty completion.
    return response.content or "I couldn't generate a response."
async def _generate_follow_ups(self, query: str, answer: str) -> list[str]:
    """Generate up to three follow-up questions for a produced answer.

    Best-effort: any LLM or parsing failure degrades to an empty list
    instead of raising.

    Args:
        query: Original query
        answer: Generated answer

    Returns:
        List of at most three follow-up question strings (possibly empty)
    """
    prompt = format_prompt(
        PromptNames.FOLLOW_UP_GEN,
        user_query=query,
        response_summary=answer[:500],  # First 500 chars keeps the prompt small
        topics=[],
        gaps=[],
    )
    messages = [
        Message(role=MessageRole.SYSTEM, content="Generate follow-up questions."),
        Message(role=MessageRole.USER, content=prompt),
    ]
    try:
        response = await self.llm.chat(messages, temperature=0.7, max_tokens=500)
        content = response.content or ""
        # Strip an optional markdown code fence around the JSON payload.
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0]
        data = json.loads(content)
        recommendations = data.get("top_3_recommendations", [])
        # BUG FIX: the model may return any JSON type here. Slicing a string
        # ("abc"[:3]) used to leak a *string* out of a list[str] function —
        # accept only a list, and only its string elements.
        if not isinstance(recommendations, list):
            return []
        return [item for item in recommendations if isinstance(item, str)][:3]
    except Exception as e:
        # Best-effort feature: record why it failed instead of hiding it.
        logger.debug(f"Follow-up generation failed: {e}")
        return []
async def _reflect_on_response(
    self, query: str, answer: str, sources: list[Citation]
) -> float:
    """Reflect on response quality and return a confidence score.

    Best-effort: any LLM or parsing failure falls back to the default
    confidence of 0.8.

    Args:
        query: Original query
        answer: Generated answer
        sources: Source citations

    Returns:
        Confidence score clamped to the range 0.0 - 1.0
    """
    prompt = format_prompt(
        PromptNames.REFLECTION,
        user_query=query,
        response=answer,
        sources=[s.url for s in sources],
        confidence=0.8,
    )
    messages = [
        Message(role=MessageRole.SYSTEM, content="Evaluate response quality."),
        Message(role=MessageRole.USER, content=prompt),
    ]
    try:
        response = await self.llm.chat(messages, temperature=0.3, max_tokens=500)
        content = response.content or ""
        # Strip an optional markdown code fence around the JSON payload.
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0]
        data = json.loads(content)
        score = data.get("overall_score", 8)
        # BUG FIX: a malformed score must not escape the documented
        # 0.0-1.0 range (e.g. overall_score=15 used to yield 1.5), and a
        # non-numeric value should fall back rather than raise.
        if not isinstance(score, (int, float)):
            return 0.8
        return min(max(score / 10.0, 0.0), 1.0)
    except Exception as e:
        # Best-effort feature: record why it failed instead of hiding it.
        logger.debug(f"Reflection failed: {e}")
        return 0.8
def _extract_sources(
    self, state: AgentState, max_sources: int = 5
) -> list[Citation]:
    """Extract deduplicated source citations from the thought history.

    Each ReACT observation is expected to be a JSON document with a
    ``results`` list of ``{title, url, snippet}`` dicts; observations that
    do not parse that way are skipped silently.

    Args:
        state: Agent state with thought history
        max_sources: Maximum number of citations to return (default 5,
            matching the previous hard-coded limit)

    Returns:
        List of citations, one per unique URL, capped at ``max_sources``
    """
    sources: list[Citation] = []
    seen_urls: set[str] = set()
    for step in state.thought_history:
        if not step.observation:
            continue
        try:
            data = json.loads(step.observation)
            results = data.get("results", [])
        except (json.JSONDecodeError, TypeError, AttributeError):
            # Observation was not a JSON object with results; nothing to cite.
            continue
        for r in results:
            url = r.get("url", "")
            # Deduplicate on URL so the same page is cited only once.
            if url and url not in seen_urls:
                seen_urls.add(url)
                sources.append(
                    Citation(
                        title=r.get("title", ""),
                        url=url,
                        snippet=r.get("snippet", ""),
                    )
                )
    return sources[:max_sources]
def _summarize_reasoning(self, state: AgentState) -> str | None:
"""Summarize the reasoning process.
Args:
state: Agent state
Returns:
Reasoning summary or None
"""
if not state.thought_history:
return None
steps = []
for step in state.thought_history:
steps.append(f"Step {step.iteration}: {step.thought[:100]}...")
return "\n".join(steps)
def _format_search_results(self, results: list[dict[str, Any]]) -> str:
"""Format search results for context.
Args:
results: Search results
Returns:
Formatted string
"""
formatted = []
for i, r in enumerate(results, 1):
formatted.append(
f"[{i}] {r.get('title', 'No title')}\n"
f"URL: {r.get('url', '')}\n"
f"Content: {r.get('snippet', '')}\n"
)
return "\n".join(formatted)
async def query(
    self,
    question: str,
    history: list[dict[str, str]] | None = None,
    enable_search: bool = True,
    max_sources: int = 5,
) -> AgentResponse:
    """Query the agent with a question.

    A convenience wrapper around ``ask()`` with a more flexible interface
    that also accepts a conversation history and a cap on the number of
    sources returned.

    Args:
        question: User's question
        history: Optional conversation history as list of {"role": str, "content": str}
        enable_search: Whether to enable web search (default True)
        max_sources: Maximum number of sources to return (default 5)

    Returns:
        AgentResponse with answer and sources
    """
    # History, when present, rides along inside the conversation context.
    conversation_context = {"conversation_history": history} if history else None
    # ReACT reasoning is what drives web search, so the two flags are coupled.
    result = await self.ask(
        query=question,
        context=conversation_context,
        use_react=enable_search,
    )
    # Trim excess sources in place only when the cap is exceeded.
    if len(result.sources) > max_sources:
        result.sources = result.sources[:max_sources]
    return result