from langchain.memory import ConversationSummaryBufferMemory
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.cache import SQLiteCache
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.tools import BaseTool
from langchain_core.messages import ToolCall
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.globals import set_llm_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any
import time
import logging
import asyncio

logger = logging.getLogger(__name__)

# Enable caching for faster responses (set_llm_cache replaces the deprecated
# langchain.llm_cache module attribute)
set_llm_cache(SQLiteCache(database_path=".langchain.db"))
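
# Note: with the cache enabled, an identical prompt issued twice is served from
# .langchain.db on the second call. Illustrative only, assuming an `llm` object
# configured elsewhere:
#   llm.invoke("2 + 2 = ?")  # first call hits the provider
#   llm.invoke("2 + 2 = ?")  # repeat is answered from the SQLite cache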

class CustomMetricsCallback(BaseCallbackHandler):
    """Custom callback for tracking metrics"""

    def __init__(self):
        self.token_count = 0
        self.start_time = None

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs):
        """Track LLM start"""
        self.start_time = time.time()
        logger.info("LLM started")

    def on_llm_end(self, response: LLMResult, **kwargs):
        """Track LLM end, duration, and (when the provider reports it) token usage"""
        if self.start_time:
            duration = time.time() - self.start_time
            logger.info(f"LLM completed in {duration:.2f}s")
        # Providers that report usage expose it via llm_output
        usage = (response.llm_output or {}).get("token_usage", {})
        self.token_count += usage.get("total_tokens", 0)

    def on_llm_error(self, error: BaseException, **kwargs):
        """Track LLM errors (LangChain passes the exception object, not a string)"""
        logger.error(f"LLM error: {error}")

class ErrorRecoveryCallback(BaseCallbackHandler):
    """Callback for error recovery strategies"""

    def __init__(self):
        self.error_count = 0
        self.last_error_time = None

    def on_llm_error(self, error: BaseException, **kwargs):
        """Handle LLM errors with recovery strategies"""
        self.error_count += 1
        self.last_error_time = time.time()
        message = str(error).lower()
        if "context length" in message:
            logger.warning("Context length exceeded, reducing context")
            # TODO: Implement context reduction
        elif "timeout" in message:
            logger.warning("LLM timeout, switching to faster model")
            # TODO: Implement model switching
        else:
            logger.error(f"Unhandled LLM error: {error}")

class ParallelToolExecutor:
    """Execute compatible tools in parallel"""

    def __init__(self, tools: List[BaseTool]):
        self.tools = {tool.name: tool for tool in tools}
        self.executor = ThreadPoolExecutor(max_workers=5)

    def _group_compatible_tools(self, tool_calls: List[ToolCall]) -> List[List[ToolCall]]:
        """Group tool calls by compatibility for parallel execution"""
        # Simple grouping: tools that don't share resources can run in parallel.
        # For now, use the tool-name prefix as a proxy for tool type.
        groups: Dict[str, List[ToolCall]] = {}
        for tool_call in tool_calls:
            tool_name = tool_call.get('name', 'unknown')
            tool_type = tool_name.split('_')[0]  # e.g. "search_web" -> "search"
            groups.setdefault(tool_type, []).append(tool_call)
        return list(groups.values())

    async def execute_parallel(
        self,
        tool_calls: List[ToolCall]
    ) -> Dict[str, Any]:
        """Execute non-conflicting tools in parallel"""
        # Group tools by compatibility
        groups = self._group_compatible_tools(tool_calls)
        results = {}
        loop = asyncio.get_running_loop()
        for group in groups:
            if len(group) == 1:
                # Single tool, execute directly on the event loop
                tool_call = group[0]
                results[tool_call.get('id', 'unknown')] = await self._execute_single(tool_call)
            else:
                # Multiple compatible tools: run them on the thread pool via
                # run_in_executor so the event loop is not blocked while waiting
                futures = [
                    (tool_call.get('id', 'unknown'),
                     loop.run_in_executor(self.executor, self._execute_single_sync, tool_call))
                    for tool_call in group
                ]
                # Collect results
                for tool_id, future in futures:
                    try:
                        results[tool_id] = await asyncio.wait_for(future, timeout=30)
                    except Exception as e:
                        logger.error(f"Tool execution failed for {tool_id}: {e}")
                        results[tool_id] = f"Error: {str(e)}"
        return results

    async def _execute_single(self, tool_call: ToolCall) -> Any:
        """Execute a single tool call"""
        tool_name = tool_call.get('name', 'unknown')
        tool_args = tool_call.get('args', {})
        if tool_name in self.tools:
            try:
                tool = self.tools[tool_name]
                result = await tool.ainvoke(tool_args)
                return result
            except Exception as e:
                logger.error(f"Tool {tool_name} failed: {e}")
                return f"Error: {str(e)}"
        else:
            return f"Tool {tool_name} not found"

    def _execute_single_sync(self, tool_call: ToolCall) -> Any:
        """Execute a single tool call synchronously"""
        tool_name = tool_call.get('name', 'unknown')
        tool_args = tool_call.get('args', {})
        if tool_name in self.tools:
            try:
                tool = self.tools[tool_name]
                result = tool.invoke(tool_args)
                return result
            except Exception as e:
                logger.error(f"Tool {tool_name} failed: {e}")
                return f"Error: {str(e)}"
        else:
            return f"Tool {tool_name} not found"

class EnhancedLangChainAgent:
    """Optimized LangChain agent with advanced features"""

    def __init__(self, llm, tools: List[BaseTool]):
        self.llm = llm
        self.tools = tools
        self.tool_executor = ParallelToolExecutor(tools)
        # Use summary buffer memory for long conversations. memory_key and
        # input_key must line up with the prompt below: the rolling summary is
        # injected as {context}, and {question} is the user input.
        self.memory = ConversationSummaryBufferMemory(
            llm=self.llm,
            max_token_limit=2000,
            memory_key="context",
            input_key="question",
            return_messages=False  # the string template expects text, not message objects
        )
        # Custom prompt optimization
        self.system_prompt = PromptTemplate(
            input_variables=["context", "question", "tools"],
            template="""You are an expert GAIA agent.

Context from previous conversation:
{context}

Available tools:
{tools}

Question: {question}

Instructions:
1. Analyze the question type
2. Select appropriate tools
3. Execute step-by-step
4. Verify results
5. Provide ONLY the final answer

Answer:"""
        )
        # Create optimized chain
        self.chain = self.create_optimized_chain()

    def create_optimized_chain(self):
        """Create chain with streaming and callbacks"""
        return LLMChain(
            llm=self.llm,
            prompt=self.system_prompt,
            memory=self.memory,
            callbacks=[
                StreamingStdOutCallbackHandler(),
                CustomMetricsCallback(),
                ErrorRecoveryCallback()
            ]
        )

    async def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the enhanced agent"""
        try:
            # Add tool descriptions to the prompt context
            tools_info = "\n".join(f"- {tool.name}: {tool.description}" for tool in self.tools)
            inputs["tools"] = tools_info
            # LLMChain.arun expects keyword arguments when the chain has
            # multiple inputs, so unpack the dict
            result = await self.chain.arun(**inputs)
            return {
                "result": result,
                "memory_summary": self.get_memory_summary()
            }
        except Exception as e:
            logger.error(f"Enhanced agent error: {e}")
            return {
                "error": str(e),
                "memory_summary": self.get_memory_summary()
            }

    def get_memory_summary(self) -> str:
        """Get memory summary for debugging"""
        try:
            return self.memory.moving_summary_buffer
        except AttributeError:
            return "Memory summary not available"

def initialize_enhanced_agent(llm, tools: List[BaseTool]):
    """Initialize enhanced LangChain agent"""
    return EnhancedLangChainAgent(llm, tools)


# Export for compatibility
enhanced_agent = initialize_enhanced_agent
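

# Minimal end-to-end sketch. Assumptions not in the original file: an
# OpenAI-compatible chat model via langchain_openai, an OPENAI_API_KEY in the
# environment, and an empty tools list purely for illustration.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    agent = initialize_enhanced_agent(llm, tools=[])
    output = asyncio.run(agent.run({"question": "What is 2 + 2?"}))
    print(output.get("result", output.get("error")))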