Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| """Main agent class that ties everything together.""" | |
| import json | |
| from typing import Any, AsyncIterator | |
| from src.agent.controller import AgentController | |
| from src.agent.models import AgentResponse, AgentState, Citation | |
| from src.llm import LLMClient, Message, MessageRole, create_llm_client | |
| from src.llm.prompts import format_prompt, get_system_prompt, PromptNames | |
| from src.tools.base import create_default_registry, ToolRegistry | |
| from src.utils.config import settings | |
| from src.utils.exceptions import AgentError | |
| from src.utils.logging import get_logger | |
| logger = get_logger(__name__) | |
class AskTheWebAgent:
    """Main Ask-the-Web Agent class.

    Wires together an LLM client, a tool registry, and an
    ``AgentController`` that drives intent parsing, planning, and the
    ReACT loop.
    """

    def __init__(
        self,
        llm_client: LLMClient | None = None,
        tool_registry: ToolRegistry | None = None,
    ) -> None:
        """Initialize the agent.

        Args:
            llm_client: Optional LLM client (creates default if not provided)
            tool_registry: Optional tool registry (creates default if not provided)
        """
        # `or` means any falsy argument (not just None) triggers the default.
        self.llm = llm_client or create_llm_client()
        self.tools = tool_registry or create_default_registry()
        # Controller owns the reasoning pipeline over the LLM + tools.
        self.controller = AgentController(self.llm, self.tools)
        self.system_prompt = get_system_prompt()
| async def ask( | |
| self, | |
| query: str, | |
| context: dict[str, Any] | None = None, | |
| use_react: bool = True, | |
| ) -> AgentResponse: | |
| """Ask the agent a question. | |
| Args: | |
| query: User's question | |
| context: Optional conversation context | |
| use_react: Whether to use ReACT reasoning (default True) | |
| Returns: | |
| AgentResponse with answer and sources | |
| """ | |
| logger.info(f"Processing query: {query}") | |
| # Initialize state | |
| state = AgentState(query=query) | |
| if context: | |
| state.working_memory.update(context) | |
| try: | |
| # 1. Parse intent | |
| state.intent = await self.controller.parse_intent(query) | |
| logger.info(f"Parsed intent: {state.intent.intent_type.value}") | |
| # 2. Plan workflow | |
| state.plan = await self.controller.plan_workflow(query, state.intent) | |
| logger.info(f"Selected strategy: {state.plan.strategy.value}") | |
| # 3. Execute based on strategy | |
| if use_react and state.intent.requires_web_search: | |
| # Use ReACT loop for complex queries | |
| answer, thought_history = await self.controller.run_react_loop(state) | |
| state.thought_history = thought_history | |
| # Extract sources from tool results | |
| sources = self._extract_sources(state) | |
| # Generate follow-up questions | |
| follow_ups = await self._generate_follow_ups(query, answer) | |
| else: | |
| # Simple execution for direct answers | |
| answer = await self._direct_answer(query, state) | |
| sources = [] | |
| follow_ups = [] | |
| # 4. Reflect on response quality (optional) | |
| confidence = 0.8 | |
| if settings.enable_reflection: | |
| confidence = await self._reflect_on_response(query, answer, sources) | |
| return AgentResponse( | |
| answer=answer, | |
| confidence=confidence, | |
| sources=sources, | |
| reasoning_summary=self._summarize_reasoning(state), | |
| follow_up_questions=follow_ups, | |
| metadata={ | |
| "intent": state.intent.intent_type.value if state.intent else None, | |
| "strategy": state.plan.strategy.value if state.plan else None, | |
| "iterations": len(state.thought_history), | |
| }, | |
| ) | |
| except Exception as e: | |
| logger.error(f"Agent error: {e}") | |
| raise AgentError(f"Failed to process query: {e}") from e | |
| async def ask_stream( | |
| self, | |
| query: str, | |
| context: dict[str, Any] | None = None, | |
| ) -> AsyncIterator[str]: | |
| """Stream a response to a question. | |
| Args: | |
| query: User's question | |
| context: Optional conversation context | |
| Yields: | |
| Response chunks as they become available | |
| """ | |
| logger.info(f"Streaming query: {query}") | |
| # For streaming, we'll do a simplified flow | |
| state = AgentState(query=query) | |
| # Parse intent first | |
| state.intent = await self.controller.parse_intent(query) | |
| if state.intent.requires_web_search: | |
| # Do search first | |
| result = await self.tools.execute( | |
| "web_search", query=query, num_results=settings.max_search_results | |
| ) | |
| if result.success: | |
| search_results = result.data.get("results", []) | |
| context_text = self._format_search_results(search_results) | |
| else: | |
| context_text = "" | |
| else: | |
| context_text = "" | |
| # Stream the response | |
| messages = [ | |
| Message(role=MessageRole.SYSTEM, content=self.system_prompt), | |
| Message( | |
| role=MessageRole.USER, | |
| content=f"Question: {query}\n\nContext from web search:\n{context_text}\n\nProvide a comprehensive answer with citations.", | |
| ), | |
| ] | |
| async for chunk in self.llm.chat_stream(messages): | |
| yield chunk | |
| async def _direct_answer(self, query: str, state: AgentState) -> str: | |
| """Generate a direct answer without web search. | |
| Args: | |
| query: User's question | |
| state: Current state | |
| Returns: | |
| Answer string | |
| """ | |
| messages = [ | |
| Message(role=MessageRole.SYSTEM, content=self.system_prompt), | |
| Message(role=MessageRole.USER, content=query), | |
| ] | |
| response = await self.llm.chat(messages, temperature=0.7) | |
| return response.content or "I couldn't generate a response." | |
| async def _generate_follow_ups(self, query: str, answer: str) -> list[str]: | |
| """Generate follow-up questions. | |
| Args: | |
| query: Original query | |
| answer: Generated answer | |
| Returns: | |
| List of follow-up questions | |
| """ | |
| prompt = format_prompt( | |
| PromptNames.FOLLOW_UP_GEN, | |
| user_query=query, | |
| response_summary=answer[:500], # First 500 chars | |
| topics=[], | |
| gaps=[], | |
| ) | |
| messages = [ | |
| Message(role=MessageRole.SYSTEM, content="Generate follow-up questions."), | |
| Message(role=MessageRole.USER, content=prompt), | |
| ] | |
| try: | |
| response = await self.llm.chat(messages, temperature=0.7, max_tokens=500) | |
| content = response.content or "" | |
| # Parse JSON response | |
| if "```json" in content: | |
| content = content.split("```json")[1].split("```")[0] | |
| data = json.loads(content) | |
| return data.get("top_3_recommendations", [])[:3] | |
| except Exception: | |
| return [] | |
| async def _reflect_on_response( | |
| self, query: str, answer: str, sources: list[Citation] | |
| ) -> float: | |
| """Reflect on response quality and return confidence score. | |
| Args: | |
| query: Original query | |
| answer: Generated answer | |
| sources: Source citations | |
| Returns: | |
| Confidence score (0.0 - 1.0) | |
| """ | |
| prompt = format_prompt( | |
| PromptNames.REFLECTION, | |
| user_query=query, | |
| response=answer, | |
| sources=[s.url for s in sources], | |
| confidence=0.8, | |
| ) | |
| messages = [ | |
| Message(role=MessageRole.SYSTEM, content="Evaluate response quality."), | |
| Message(role=MessageRole.USER, content=prompt), | |
| ] | |
| try: | |
| response = await self.llm.chat(messages, temperature=0.3, max_tokens=500) | |
| content = response.content or "" | |
| if "```json" in content: | |
| content = content.split("```json")[1].split("```")[0] | |
| data = json.loads(content) | |
| return data.get("overall_score", 8) / 10.0 | |
| except Exception: | |
| return 0.8 | |
| def _extract_sources(self, state: AgentState) -> list[Citation]: | |
| """Extract sources from thought history. | |
| Args: | |
| state: Agent state with thought history | |
| Returns: | |
| List of citations | |
| """ | |
| sources = [] | |
| seen_urls = set() | |
| for step in state.thought_history: | |
| if step.observation: | |
| try: | |
| data = json.loads(step.observation) | |
| results = data.get("results", []) | |
| for r in results: | |
| url = r.get("url", "") | |
| if url and url not in seen_urls: | |
| sources.append( | |
| Citation( | |
| title=r.get("title", ""), | |
| url=url, | |
| snippet=r.get("snippet", ""), | |
| ) | |
| ) | |
| seen_urls.add(url) | |
| except (json.JSONDecodeError, TypeError): | |
| pass | |
| return sources[:5] # Limit to 5 sources | |
| def _summarize_reasoning(self, state: AgentState) -> str | None: | |
| """Summarize the reasoning process. | |
| Args: | |
| state: Agent state | |
| Returns: | |
| Reasoning summary or None | |
| """ | |
| if not state.thought_history: | |
| return None | |
| steps = [] | |
| for step in state.thought_history: | |
| steps.append(f"Step {step.iteration}: {step.thought[:100]}...") | |
| return "\n".join(steps) | |
| def _format_search_results(self, results: list[dict[str, Any]]) -> str: | |
| """Format search results for context. | |
| Args: | |
| results: Search results | |
| Returns: | |
| Formatted string | |
| """ | |
| formatted = [] | |
| for i, r in enumerate(results, 1): | |
| formatted.append( | |
| f"[{i}] {r.get('title', 'No title')}\n" | |
| f"URL: {r.get('url', '')}\n" | |
| f"Content: {r.get('snippet', '')}\n" | |
| ) | |
| return "\n".join(formatted) | |
| async def query( | |
| self, | |
| question: str, | |
| history: list[dict[str, str]] | None = None, | |
| enable_search: bool = True, | |
| max_sources: int = 5, | |
| ) -> AgentResponse: | |
| """Query the agent with a question. | |
| This is an alias for the `ask()` method with a more flexible interface | |
| that supports conversation history and search configuration. | |
| Args: | |
| question: User's question | |
| history: Optional conversation history as list of {"role": str, "content": str} | |
| enable_search: Whether to enable web search (default True) | |
| max_sources: Maximum number of sources to return (default 5) | |
| Returns: | |
| AgentResponse with answer and sources | |
| """ | |
| # Build context from history | |
| context = None | |
| if history: | |
| context = {"conversation_history": history} | |
| # Call the main ask method | |
| response = await self.ask( | |
| query=question, | |
| context=context, | |
| use_react=enable_search, # Use ReACT when search is enabled | |
| ) | |
| # Limit sources if needed | |
| if len(response.sources) > max_sources: | |
| response.sources = response.sources[:max_sources] | |
| return response | |