# NOTE: The three header lines here ("Spaces:" / "Runtime error" x2) were
# artifacts from the hosting platform's export, not part of the program.
# They are preserved as this comment so the file parses as Python.
| import os | |
| from typing import List, Dict, Any, Optional | |
| from dotenv import load_dotenv | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| from langchain.agents import AgentType, initialize_agent, Tool | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
| from langchain_core.tools import tool | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chains import LLMChain | |
# Load environment variables
# API credentials read from the process environment (populated from .env by the
# load_dotenv() call above). Any of these may be None when the variable is
# unset; each consumer checks for a missing key before use and reports an
# explicit error instead of crashing.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
def calculator_tool(operation: str, a: float, b: float) -> str:
    """Perform basic mathematical operations: add, subtract, multiply, divide, modulus

    Args:
        operation: The operation to perform (add, subtract, multiply, divide, modulus)
        a: First number
        b: Second number

    Returns:
        Result of the mathematical operation as a string, or an "Error: ..."
        message for an unsupported operation or a zero divisor.
    """
    # Dispatch table instead of an if/elif chain; keys double as the list of
    # supported operation names.
    operations = {
        "add": lambda x, y: x + y,
        "subtract": lambda x, y: x - y,
        "multiply": lambda x, y: x * y,
        "divide": lambda x, y: x / y,
        "modulus": lambda x, y: x % y,
    }
    try:
        if operation not in operations:
            return "Error: Unsupported operation. Use: add, subtract, multiply, divide, modulus"
        # Guard modulus as well as divide: `a % 0` raises ZeroDivisionError just
        # like `a / 0`, and previously fell through to a raw exception message.
        if b == 0 and operation in ("divide", "modulus"):
            return "Error: Cannot divide by zero"
        return str(operations[operation](a, b))
    except Exception as e:
        return f"Error: {str(e)}"
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for information on any topic

    Args:
        query: The search query for Wikipedia

    Returns:
        Formatted Wikipedia search results, or an explicit "No ... results"
        message when nothing is found, or an "Error: ..." message on failure.
    """
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        # An empty hit list would otherwise join to "" — make it explicit.
        if not search_docs:
            return f"No Wikipedia results found for: {query}"
        # Use .get for "source" too: a missing key previously raised KeyError,
        # which the broad except silently turned into an unrelated error string.
        formatted_results = "\n\n---\n\n".join(
            f'Source: {doc.metadata.get("source", "")}\nPage: {doc.metadata.get("page", "")}\n\nContent:\n{doc.page_content[:2000]}...'
            for doc in search_docs
        )
        return formatted_results
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
def web_search_tool(query: str) -> str:
    """Search the web for current information using Tavily

    Args:
        query: The search query for web search

    Returns:
        Formatted web search results, or an explicit "No ... results" message
        when nothing is found, or an "Error: ..." message on failure.
    """
    try:
        if not TAVILY_API_KEY:
            return "Error: TAVILY_API_KEY not found in environment variables"
        search_results = TavilySearchResults(max_results=3, api_key=TAVILY_API_KEY).invoke(query)
        # An empty hit list would otherwise join to "" — make it explicit.
        if not search_results:
            return f"No web results found for: {query}"
        formatted_results = "\n\n---\n\n".join(
            f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
            for result in search_results
        )
        return formatted_results
    except Exception as e:
        return f"Error searching web: {str(e)}"
def arxiv_search_tool(query: str) -> str:
    """Search ArXiv for academic papers and research

    Args:
        query: The search query for ArXiv

    Returns:
        Formatted ArXiv search results, or an explicit "No ... results"
        message when nothing is found, or an "Error: ..." message on failure.
    """
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        # An empty hit list would otherwise join to "" — make it explicit.
        if not search_docs:
            return f"No ArXiv results found for: {query}"
        # Use .get for "source" too: a missing key previously raised KeyError,
        # which the broad except silently turned into an unrelated error string.
        formatted_results = "\n\n---\n\n".join(
            f'Source: {doc.metadata.get("source", "")}\nTitle: {doc.metadata.get("Title", "")}\n\nContent:\n{doc.page_content[:1500]}...'
            for doc in search_docs
        )
        return formatted_results
    except Exception as e:
        return f"Error searching ArXiv: {str(e)}"
class LangChainAgent:
    """Multi-purpose LangChain agent with various capabilities.

    Wraps an LLM (Google Gemini or a Hugging Face endpoint) in a
    ZERO_SHOT_REACT_DESCRIPTION agent with calculator / Wikipedia / web /
    ArXiv tools and a conversation-buffer memory. If agent construction
    fails, __call__ degrades to answering with the bare LLM.
    """

    def __init__(self, provider: str = "google"):
        """Initialize the LangChain agent with specified LLM provider.

        Args:
            provider: "google" or "huggingface"; any other value makes
                _get_llm raise ValueError.
        """
        self.provider = provider
        self.llm = self._get_llm(provider)
        self.tools = self._initialize_tools()
        # return_messages=True keeps the history as message objects rather
        # than one concatenated string.
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        # May be None when initialize_agent fails; __call__ checks for that.
        self.agent = self._create_agent()

    def _get_llm(self, provider: str):
        """Get the specified LLM.

        Raises:
            ValueError: if the provider's API key/token is missing from the
                environment, or the provider name is not recognized.
        """
        if provider == "google":
            if not GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY not found in environment variables")
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,  # deterministic output for tool-use reasoning
                max_tokens=2048,
                google_api_key=GOOGLE_API_KEY
            )
        elif provider == "huggingface":
            if not HUGGINGFACE_API_TOKEN:
                raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id="microsoft/DialoGPT-medium",
                    temperature=0,
                    max_length=2048,
                    huggingfacehub_api_token=HUGGINGFACE_API_TOKEN
                ),
            )
        else:
            raise ValueError("Invalid provider. Choose 'google' or 'huggingface'.")

    def _initialize_tools(self) -> List[Tool]:
        """Initialize all available tools.

        NOTE(review): these are plain Python functions, not Tool/BaseTool
        instances, despite the List[Tool] annotation — the `tool` decorator
        imported at the top of the file is never applied to them. Confirm
        whether the tool functions were meant to be decorated with @tool
        (or wrapped via Tool/StructuredTool) before initialize_agent uses them.
        """
        return [
            calculator_tool,
            wikipedia_search_tool,
            web_search_tool,
            arxiv_search_tool,
        ]

    def _create_agent(self):
        """Create the LangChain agent with tools.

        Returns:
            The initialized agent executor, or None when construction fails
            (callers must handle the None fallback).
        """
        try:
            return initialize_agent(
                tools=self.tools,
                llm=self.llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                memory=self.memory,
                verbose=True,
                handle_parsing_errors=True,
                # Cap the think/act loop so a confused model cannot spin forever.
                max_iterations=3,
                early_stopping_method="generate"
            )
        except Exception as e:
            print(f"Error creating agent: {e}")
            # Return a simple agent without tools as fallback
            return None

    def _determine_approach(self, question: str) -> str:
        """Determine the best approach for answering the question.

        Keyword heuristics checked in order: calculation, research, academic;
        first match wins, otherwise 'general'. Note the order matters — e.g.
        'research' appears in both the research and academic keyword lists,
        so such questions classify as 'research'.
        """
        question_lower = question.lower()
        # Check for mathematical operations
        math_keywords = ['calculate', 'compute', 'add', 'subtract', 'multiply', 'divide', 'math', 'equation', '+', '-', '*', '/', '%']
        if any(keyword in question_lower for keyword in math_keywords):
            return 'calculation'
        # Check for research-related queries
        research_keywords = ['search', 'find', 'research', 'information', 'what is', 'who is', 'when', 'where', 'how', 'why']
        if any(keyword in question_lower for keyword in research_keywords):
            return 'research'
        # Check for academic/scientific queries
        academic_keywords = ['paper', 'study', 'research', 'academic', 'scientific', 'arxiv', 'journal']
        if any(keyword in question_lower for keyword in academic_keywords):
            return 'academic'
        return 'general'

    def __call__(self, question: str) -> str:
        """Process a question and return an answer.

        Flow: classify the question, wrap it in an approach-specific prompt,
        run the agent; on any failure fall back to a direct LLM call.
        """
        try:
            print(f"Processing question: {question[:100]}...")
            # If agent initialization failed, use direct LLM
            if self.agent is None:
                print("Agent not available, using direct LLM response")
                try:
                    response = self.llm.invoke([HumanMessage(content=question)])
                    return response.content
                except Exception as llm_error:
                    return f"Error: Unable to process question. {str(llm_error)}"
            # Determine the best approach for this question
            approach = self._determine_approach(question)
            print(f"Selected approach: {approach}")
            # Create a comprehensive prompt based on the approach
            if approach == 'calculation':
                enhanced_question = f"""
Solve this mathematical problem step by step:
{question}
Use the calculator tool if needed for complex calculations. Show your work clearly.
"""
            elif approach == 'research':
                enhanced_question = f"""
Research and provide comprehensive information about:
{question}
Use Wikipedia search and web search tools to gather current and accurate information.
Cite your sources and provide detailed explanations.
"""
            elif approach == 'academic':
                enhanced_question = f"""
Find academic and scientific information about:
{question}
Use ArXiv search and other research tools to find relevant academic papers and studies.
Provide citations and summarize key findings.
"""
            else:
                enhanced_question = f"""
Provide a comprehensive answer to:
{question}
Use appropriate tools as needed (calculator, search tools) to provide accurate information.
"""
            # Use the agent to process the question
            result = self.agent.run(enhanced_question)
            print(f"Generated answer: {str(result)[:200]}...")
            return str(result)
        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)
            # Provide a fallback response
            try:
                # Try a simple LLM response without tools
                fallback_result = self.llm.invoke([HumanMessage(content=question)])
                return fallback_result.content
            except Exception as fallback_error:
                # NOTE(review): fallback_error is captured but the message below
                # reports the original exception `e` — confirm whether the
                # fallback's own error was meant to be surfaced instead.
                return f"Error: Unable to process question. {str(e)}"

    def reset_memory(self):
        """Reset the conversation memory (clears the chat history buffer)."""
        self.memory.clear()
# Test function
def test_langchain_agent():
    """Smoke-test the LangChain agent against a few representative questions."""
    agent = LangChainAgent(provider="google")
    sample_questions = (
        "What is 25 * 34?",
        "Who was Albert Einstein?",
        "Search for recent developments in artificial intelligence",
        "What is the theory of relativity?",
    )
    for q in sample_questions:
        print(f"\nQuestion: {q}")
        print(f"Answer: {agent(q)}")
        print("-" * 50)
        # Keep each question independent by wiping the chat history.
        agent.reset_memory()

if __name__ == "__main__":
    test_langchain_agent()