Spaces:
Sleeping
Sleeping
| """Base agent class with common functionality.""" | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import List, Dict, Any, Optional, Callable | |
| from openai import OpenAI | |
| from src.core.config import get_settings | |
| from src.memory.short_term_memory import ShortTermMemory | |
| from src.memory.long_term_memory import LongTermMemory | |
| from src.planning.react_planner import ReActPlanner | |
| from src.planning.cot_planner import CoTPlanner | |
# Module-level logger, named after this module (standard `getLogger(__name__)` pattern).
logger = logging.getLogger(name=__name__)
class BaseAgent(ABC):
    """Base class for all agents.

    Bundles an OpenAI chat client, optional short-/long-term memory, an
    optional ReAct or chain-of-thought planner, and a registry of callable
    tools. ``process`` is the main entry point.
    """

    def __init__(
        self,
        name: str,
        description: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        use_memory: bool = True,
        use_planning: bool = False,
        planning_type: str = "react",  # "react" or "cot"
    ):
        """
        Initialize base agent.

        Args:
            name: Agent name.
            description: Agent description; also sent to the LLM as the
                system prompt.
            tools: List of tool schemas. Each must have a ``"name"`` key
                (``add_tool`` and ``get_status`` read it). The list is copied,
                so the caller's list is never mutated.
            use_memory: Whether to create short- and long-term memory.
            use_planning: Whether to use a planner.
            planning_type: ``"react"`` or ``"cot"``. Any other value logs a
                warning and leaves ``self.planner`` as ``None``, in which case
                queries are answered directly.
        """
        self.name = name
        self.description = description
        self.settings = get_settings()

        # Initialize OpenAI client
        self.client = OpenAI(**self.settings.get_openai_client_kwargs())
        self.model = self.settings.openai_model

        # Initialize memory
        self.use_memory = use_memory
        self.short_term_memory: Optional[ShortTermMemory] = None
        self.long_term_memory: Optional[LongTermMemory] = None
        if use_memory:
            self.short_term_memory = ShortTermMemory()
            self.long_term_memory = LongTermMemory()

        # Initialize planning. The planner gets its own copy of the tool list.
        # Previously it received the same list object later stored in
        # self.tools. If the planner keeps that list, add_tool() would append
        # every new tool to it twice.
        self.use_planning = use_planning
        self.planning_type = planning_type
        self.planner: Optional[ReActPlanner | CoTPlanner] = None
        if use_planning:
            if planning_type == "react":
                self.planner = ReActPlanner(tools=list(tools or []))
            elif planning_type == "cot":
                self.planner = CoTPlanner()
            else:
                logger.warning("Unknown planning type: %s", planning_type)

        # Tools (copied so add_tool() never mutates the caller's list)
        self.tools: List[Dict[str, Any]] = list(tools or [])
        self.tool_functions: Dict[str, Callable] = {}

    def add_tool(self, tool: Dict[str, Any], tool_function: Callable) -> None:
        """
        Register a tool and the function that executes it.

        Args:
            tool: Tool schema; must contain a ``"name"`` key.
            tool_function: Sync or async callable invoked by ``_execute_tool``.
        """
        self.tools.append(tool)
        self.tool_functions[tool["name"]] = tool_function
        # Only the ReAct planner is built with a tool list; keep it in sync.
        if self.planner is not None and isinstance(self.planner, ReActPlanner):
            self.planner.add_tool(tool)

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Process a query using the agent.

        Records the query (and the answer) in short-term memory, pulls up to
        three related long-term memories for ``session_id``, then answers
        either via the planner or with a single LLM call.

        Args:
            query: User query.
            session_id: Optional session ID; long-term memory is only read and
                written when it is given.
            context: Optional additional context.

        Returns:
            Response dictionary. On any exception the error is logged and
            ``{"success": False, "error": <message>, "agent": <name>}`` is
            returned instead of raising.
        """
        try:
            # `is not None` rather than truthiness: a memory object that
            # defines __len__ would be falsy while empty and get skipped.
            if self.short_term_memory is not None:
                self.short_term_memory.add_message("user", query)

            # Load long-term memory if available
            long_term_context = ""
            if self.long_term_memory is not None and session_id:
                memories = self.long_term_memory.search_memories(
                    query, session_id, n_results=3
                )
                if memories:
                    long_term_context = "\n".join(m["content"] for m in memories)

            full_context = self._build_context(context, long_term_context)

            if self.use_planning and self.planner is not None:
                response = await self._process_with_planning(
                    query, full_context, session_id
                )
            else:
                response = await self._process_direct(query, full_context, session_id)

            # The LLM may return content=None; don't write that into history.
            answer = response.get("answer")
            if self.short_term_memory is not None and answer is not None:
                self.short_term_memory.add_message("assistant", answer)

            # Store in long-term memory
            if self.long_term_memory is not None and session_id:
                messages = (
                    self.short_term_memory.get_messages()
                    if self.short_term_memory is not None
                    else []
                )
                self.long_term_memory.store_conversation(session_id, messages)

            return response
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error processing query in %s: %s", self.name, e)
            return {
                "success": False,
                "error": str(e),
                "agent": self.name,
            }

    def _complete(self, messages: List[Dict[str, Any]]) -> Any:
        """Make one blocking chat-completion call with the agent's model settings."""
        return self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.7,
        )

    async def _process_direct(
        self,
        query: str,
        context: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """
        Answer the query with a single chat-completion call (no planning).

        Args:
            query: User query. Sent explicitly only when there is no
                short-term memory. Otherwise the last 5 history messages are
                sent; when called via ``process`` these already end with the
                query.
            context: Context appended to the system prompt (may be empty).
            session_id: Not used by this implementation.

        Returns:
            ``{"success": True, "answer", "agent", "model"}``.

        Raises:
            Exception: With a friendly message for quota/429 and
                API-key/401 failures (the original error is chained).
                Any other error is re-raised unchanged.
        """
        import asyncio  # local import, matching _execute_tool

        system_prompt = (
            f"{self.description}\n\nContext: {context}" if context else self.description
        )
        messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}]

        if self.short_term_memory is not None:
            history = self.short_term_memory.get_messages(format_for_llm=True)
            messages.extend(history[-5:])  # Last 5 messages
        else:
            messages.append({"role": "user", "content": query})

        try:
            # The OpenAI client is synchronous; run it in a worker thread so
            # the event loop is not blocked during the HTTP request.
            response = await asyncio.to_thread(self._complete, messages)
        except Exception as e:
            error_msg = str(e).lower()
            if "quota" in error_msg or "429" in error_msg:
                logger.error("OpenAI API quota exceeded: %s", e)
                raise Exception(
                    "OpenAI API quota exceeded. Please check your billing and plan details."
                ) from e
            if "api key" in error_msg or "401" in error_msg:
                logger.error("Invalid OpenAI API key: %s", e)
                raise Exception(
                    "Invalid OpenAI API key. Please check your .env file."
                ) from e
            logger.error("Error calling LLM: %s", e)
            raise

        answer = response.choices[0].message.content
        return {
            "success": True,
            "answer": answer,
            "agent": self.name,
            "model": self.model,
        }

    async def _process_with_planning(
        self,
        query: str,
        context: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """
        Answer the query through the configured planner.

        The planner's ``plan`` method and the ``llm_call`` callback it receives
        are synchronous, so the whole plan runs in a worker thread to keep the
        event loop responsive. Falls back to ``_process_direct`` when there is
        no planner.

        Returns:
            ``{"success": True, "answer", "agent", "plan", "model"}``. The
            answer is the plan's ``"final_answer"`` (ReAct) or ``"conclusion"``
            (CoT), with a fallback message if that key is missing.
        """
        if self.planner is None:
            return await self._process_direct(query, context, session_id)

        import asyncio  # local import, matching _execute_tool

        def llm_call(prompt: str) -> str:
            """Synchronous single-turn LLM call handed to the planner."""
            response = self._complete([
                {"role": "system", "content": self.description},
                {"role": "user", "content": prompt},
            ])
            return response.choices[0].message.content

        # Both planner types share the same plan() call signature.
        plan = await asyncio.to_thread(
            self.planner.plan,
            query=query,
            context=context,
            llm_call=llm_call,
        )

        answer_key = (
            "final_answer" if isinstance(self.planner, ReActPlanner) else "conclusion"
        )
        answer = plan.get(answer_key, "I couldn't find a complete answer.")

        return {
            "success": True,
            "answer": answer,
            "agent": self.name,
            "plan": plan,
            "model": self.model,
        }

    def _build_context(
        self,
        additional_context: Optional[str],
        long_term_context: str,
    ) -> str:
        """Join the long-term and additional context into one labelled string.

        Returns ``""`` when both parts are empty.
        """
        parts = []
        if long_term_context:
            parts.append(f"Relevant past conversations:\n{long_term_context}")
        if additional_context:
            parts.append(f"Additional context:\n{additional_context}")
        return "\n\n".join(parts)

    async def _execute_tool(
        self,
        tool_name: str,
        **kwargs,
    ) -> Any:
        """
        Execute a registered tool (sync or async) and return its result.

        The tool is called and, if the result is awaitable, awaited. This
        also covers callables that are not coroutine functions themselves,
        e.g. objects with an ``async def __call__``.

        Raises:
            ValueError: If no tool named ``tool_name`` is registered.
        """
        import inspect

        try:
            tool_func = self.tool_functions[tool_name]
        except KeyError:
            raise ValueError(f"Tool '{tool_name}' not found") from None

        result = tool_func(**kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve relevant context for the query.

        Hook for subclasses. The base implementation has no context source
        and returns an empty string, honouring the ``str`` return type (it
        previously returned ``None``).

        Args:
            query: User query

        Returns:
            Context string
        """
        return ""

    def get_status(self) -> Dict[str, Any]:
        """Return a summary of the agent's name, tools, and memory/planning settings."""
        return {
            "name": self.name,
            "description": self.description,
            "tools": [t["name"] for t in self.tools],
            "memory_enabled": self.use_memory,
            "planning_enabled": self.use_planning,
            "planning_type": self.planning_type if self.use_planning else None,
        }