| | import os |
| | from typing import List, Dict, Any, Optional |
| | from dotenv import load_dotenv |
| |
|
| | |
# Load key/value pairs from a .env file into the process environment so the
# environment-variable lookups for the API keys further down can see them.
load_dotenv()
| |
|
| | from langchain.agents import AgentType, initialize_agent, Tool |
| | from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory |
| | from langchain_core.messages import BaseMessage, HumanMessage, AIMessage |
| | from langchain_google_genai import ChatGoogleGenerativeAI |
| | from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint |
| | from langchain_groq import ChatGroq |
| | from langchain_community.tools.tavily_search import TavilySearchResults |
| | from langchain_community.document_loaders import WikipediaLoader, ArxivLoader |
| | from langchain_core.tools import tool |
| | from langchain.prompts import PromptTemplate |
| | from langchain.chains import LLMChain |
| |
|
| | |
# Credentials for every supported backend, read once at import time. A value
# is None when its environment variable is unset; the code that needs a key
# checks for None and reports which variable is missing.
GOOGLE_API_KEY, HUGGINGFACE_API_TOKEN, GROQ_API_KEY, TAVILY_API_KEY = (
    os.environ.get(variable_name)
    for variable_name in (
        'GOOGLE_API_KEY',
        'HUGGINGFACE_API_TOKEN',
        'GROQ_API_KEY',
        'TAVILY_API_KEY',
    )
)
| |
|
| |
|
@tool
def calculator_tool(expression: str) -> str:
    """Perform mathematical calculations and evaluate expressions

    Supports + - * / // % ** (also written ^), parentheses, the functions
    sqrt, sin, cos, tan, log, ln, log10, pow, abs, round, min, max and the
    constants pi and e.

    Args:
        expression: A mathematical expression to evaluate (e.g., "2+2", "25*34", "sqrt(16)", "sin(0.5)")

    Returns:
        Result of the mathematical expression, or a message starting with "Error:"
    """
    # The docstring above is shown to the LLM as the tool description, so
    # implementation notes live here instead.
    #
    # The expression is parsed with ``ast`` and evaluated by walking the tree.
    # Only numeric literals, arithmetic operators and the whitelisted names
    # below are accepted. This replaces a chain of blind ``str.replace``
    # rewrites that corrupted valid input. For example, "log10(100)" became
    # "math.math.log10(100)", "atan(1)" became "amath.tan(1)" and "1e3"
    # became "1math.e3". It also avoids passing model-generated text to eval().
    import ast
    import math
    import operator
    import re

    invalid_msg = "Error: Invalid characters in expression. Use only numbers and basic math operations."
    # Upper bound on the bit length of an exact integer power, so input such
    # as "9^9^9" fails fast instead of tying up the process.
    max_power_bits = 100_000

    class UnsupportedExpression(Exception):
        """Raised for any syntax node or name outside the whitelist."""

    def power(base, exponent):
        """Return ``base ** exponent``, refusing oversized exact integer results."""
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and abs(base) > 1
            and base.bit_length() * exponent > max_power_bits
        ):
            raise ValueError("result is too large to compute")
        return base ** exponent

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: power,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    functions = {
        "sqrt": math.sqrt,
        "sin": math.sin,
        "cos": math.cos,
        "tan": math.tan,
        "log": math.log,
        "ln": math.log,
        "log10": math.log10,
        "pow": power,
        "abs": abs,
        "round": round,
        "min": min,
        "max": max,
    }
    constants = {"pi": math.pi, "e": math.e}

    def evaluate(node):
        """Recursively evaluate a whitelisted expression node."""
        # type() rather than isinstance() so that True/False are rejected.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.Name) and node.id in constants:
            return constants[node.id]
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in functions
            and not node.keywords
        ):
            return functions[node.func.id](*(evaluate(arg) for arg in node.args))
        raise UnsupportedExpression(ast.dump(node))

    try:
        text = expression.strip()
        # ReAct agents sometimes pass the argument still wrapped in quotes.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
            text = text[1:-1].strip()
        text = re.sub(r"\bmath\.", "", text)  # accept "math.sqrt(2)" as "sqrt(2)"
        text = text.replace("^", "**")  # caret means power here, never XOR
        text = re.sub(r"\bpow\b(?!\s*\()", "**", text)  # infix form "2 pow 8"

        tree = ast.parse(text, mode="eval")
        return str(evaluate(tree.body))
    except ZeroDivisionError:
        return "Error: Cannot divide by zero"
    except UnsupportedExpression:
        return invalid_msg
    except Exception as e:
        return f"Error: {str(e)}"
| |
|
@tool
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for information on any topic

    Args:
        query: The search query for Wikipedia

    Returns:
        Formatted Wikipedia search results
    """
    # The docstring above is shown to the LLM as the tool description.
    # Failures are returned as text (never raised) so the agent can react.
    max_chars = 2000  # per-article content budget kept in the observation
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            # An empty observation gives the agent nothing to reason about.
            return f"No Wikipedia results found for: {query}"

        sections = []
        for doc in search_docs:
            meta = doc.metadata
            content = doc.page_content
            if len(content) > max_chars:
                # Mark truncation only when something was actually cut off.
                content = content[:max_chars] + "..."
            # NOTE(review): the loader's metadata may carry "title" rather than
            # "page"; fall back to it so the field is not always blank. Confirm
            # the keys against the installed langchain_community version.
            page = meta.get("page") or meta.get("title", "")
            sections.append(
                f'Source: {meta.get("source", "")}\nPage: {page}\n\nContent:\n{content}'
            )
        return "\n\n---\n\n".join(sections)
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
| |
|
@tool
def web_search_tool(query: str) -> str:
    """Search the web for current information using Tavily

    Args:
        query: The search query for web search

    Returns:
        Formatted web search results
    """
    # The docstring above is shown to the LLM as the tool description.
    # Failures are returned as text (never raised) so the agent can react.
    try:
        if not TAVILY_API_KEY:
            return "Error: TAVILY_API_KEY not found in environment variables"

        search_results = TavilySearchResults(max_results=3, api_key=TAVILY_API_KEY).invoke(query)

        # NOTE(review): depending on the langchain_community version, the tool
        # may return a plain string (e.g. an error description) instead of a
        # list of result dicts. Pass such a string through rather than failing
        # on `.get` below. Confirm against the installed version.
        if isinstance(search_results, str):
            return search_results
        if not search_results:
            # An empty observation gives the agent nothing to reason about.
            return f"No web results found for: {query}"

        return "\n\n---\n\n".join(
            f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
            for result in search_results
        )
    except Exception as e:
        return f"Error searching web: {str(e)}"
| |
|
@tool
def arxiv_search_tool(query: str) -> str:
    """Search ArXiv for academic papers and research

    Args:
        query: The search query for ArXiv

    Returns:
        Formatted ArXiv search results
    """
    # The docstring above is shown to the LLM as the tool description.
    # Failures are returned as text (never raised) so the agent can react.
    max_chars = 1500  # per-paper content budget kept in the observation
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            # An empty observation gives the agent nothing to reason about.
            return f"No ArXiv results found for: {query}"

        sections = []
        for doc in search_docs:
            meta = doc.metadata
            content = doc.page_content
            if len(content) > max_chars:
                # Mark truncation only when something was actually cut off.
                content = content[:max_chars] + "..."
            # Read "source" with .get, like "Title": a hard ["source"] lookup
            # raises KeyError, and the whole search is lost, if the key is
            # missing.
            # NOTE(review): "entry_id" as the fallback key is an assumption;
            # confirm the metadata keys for the installed version.
            source = meta.get("source") or meta.get("entry_id", "")
            sections.append(
                f'Source: {source}\nTitle: {meta.get("Title", "")}\n\nContent:\n{content}'
            )
        return "\n\n---\n\n".join(sections)
    except Exception as e:
        return f"Error searching ArXiv: {str(e)}"
| |
|
class LangChainAgent:
    """Multi-purpose LangChain agent with various capabilities.

    Combines one chat model (Groq, Google Gemini or a Hugging Face endpoint)
    with four tools in a LangChain ``ZERO_SHOT_REACT_DESCRIPTION`` agent. The
    tools are calculator, Wikipedia, Tavily web search and ArXiv. Each
    question is classified by keyword heuristics and wrapped in a matching
    instruction prompt before it reaches the agent. If the agent could not
    be created, or fails while answering, the bare LLM is asked directly.
    """

    # Keywords used by _determine_approach. Each must match whole words,
    # optionally followed by a common inflection (see _KEYWORD_SUFFIX), so
    # "add" does not fire on "address" and "how" does not fire on "show".
    _MATH_KEYWORDS = (
        'calculate', 'compute', 'add', 'subtract', 'multiply', 'divide', 'math', 'equation',
    )
    # "research" is deliberately not listed here. It is also a research
    # keyword, and questions using it go to the research approach.
    _ACADEMIC_KEYWORDS = ('paper', 'study', 'academic', 'scientific', 'arxiv', 'journal')
    _RESEARCH_KEYWORDS = (
        'search', 'find', 'research', 'information',
        'what is', 'who is', 'when', 'where', 'how', 'why',
    )
    # Inflections tolerated after a keyword ("papers", "calculated", "adding").
    _KEYWORD_SUFFIX = r'(?:s|es|d|ed|ing)?'
    # An operator symbol signals arithmetic only when it sits next to numbers.
    # This keeps hyphenated words ("state-of-the-art") and slashes in prose or
    # URLs from being mistaken for calculations.
    _ARITHMETIC_PATTERN = r'\d\s*[-+*/^%]\s*[\d(]|\d\s*%'

    def __init__(self, provider: str = "groq"):
        """Initialize the LangChain agent with specified LLM provider.

        Args:
            provider: One of "groq", "google" or "huggingface".

        Raises:
            ValueError: If the provider is unknown or its API key is not set.
        """
        self.provider = provider
        self.llm = self._get_llm(provider)
        self.tools = self._initialize_tools()

        # Conversation memory handed to the agent; cleared by reset_memory().
        # NOTE(review): the ZERO_SHOT_REACT_DESCRIPTION prompt may not contain
        # a "chat_history" variable. If so, this memory is recorded but never
        # shown to the model. Confirm.
        self.memory = ConversationSummaryBufferMemory(
            llm=self.llm,
            memory_key="chat_history",
            return_messages=True,
            max_token_limit=2000,
            moving_summary_buffer="The human and AI are having a conversation about various topics."
        )
        self.agent = self._create_agent()

    def _get_llm(self, provider: str):
        """Build the chat model for ``provider``.

        Raises:
            ValueError: If the provider is unknown or its API key/token
                environment variable is not set.
        """
        if provider == "groq":
            if not GROQ_API_KEY:
                raise ValueError("GROQ_API_KEY not found in environment variables")
            return ChatGroq(
                model="llama-3.3-70b-versatile",
                temperature=0.1,
                max_tokens=8192,
                api_key=GROQ_API_KEY,
                streaming=False
            )
        elif provider == "google":
            if not GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY not found in environment variables")
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,
                max_tokens=2048,
                google_api_key=GOOGLE_API_KEY
            )
        elif provider == "huggingface":
            if not HUGGINGFACE_API_TOKEN:
                raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
            # NOTE(review): DialoGPT-medium is a small conversational model and
            # may not follow the ReAct format. "max_length" may also not be a
            # recognised HuggingFaceEndpoint parameter ("max_new_tokens"?).
            # Confirm both.
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id="microsoft/DialoGPT-medium",
                    temperature=0,
                    max_length=2048,
                    huggingfacehub_api_token=HUGGINGFACE_API_TOKEN
                ),
            )
        else:
            raise ValueError("Invalid provider. Choose 'groq', 'google' or 'huggingface'.")

    def _initialize_tools(self) -> List[Any]:
        """Return the @tool-decorated tool objects exposed to the agent."""
        return [
            calculator_tool,
            wikipedia_search_tool,
            web_search_tool,
            arxiv_search_tool,
        ]

    def _create_agent(self):
        """Create the ReAct agent executor over ``self.tools``.

        Returns None (after printing the error) if construction fails. In
        that case __call__ answers with the bare LLM instead.
        """
        try:
            return initialize_agent(
                tools=self.tools,
                llm=self.llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                memory=self.memory,
                verbose=True,
                handle_parsing_errors=True,
                max_iterations=4,
                early_stopping_method="generate",
                return_intermediate_steps=False
            )
        except Exception as e:
            print(f"Error creating agent: {e}")
            return None

    def _determine_approach(self, question: str) -> str:
        """Classify ``question`` by keywords.

        Returns:
            One of 'calculation', 'academic', 'research' or 'general'.
        """
        import re

        text = question.lower()

        def mentions(keywords) -> bool:
            """True if any keyword occurs as a whole (possibly inflected) word."""
            alternatives = '|'.join(re.escape(word) for word in keywords)
            pattern = rf'\b(?:{alternatives}){self._KEYWORD_SUFFIX}\b'
            return re.search(pattern, text) is not None

        if mentions(self._MATH_KEYWORDS) or re.search(self._ARITHMETIC_PATTERN, text):
            return 'calculation'
        # Academic is checked before research. Its keywords are more specific,
        # and the broad research keywords ("what is", "how", "find") would
        # otherwise swallow questions such as "find papers on X".
        if mentions(self._ACADEMIC_KEYWORDS):
            return 'academic'
        if mentions(self._RESEARCH_KEYWORDS):
            return 'research'
        return 'general'

    @staticmethod
    def _build_prompt(question: str, approach: str) -> str:
        """Wrap ``question`` in the instruction template for ``approach``.

        Unknown approaches fall back to the general-purpose template.
        """
        import textwrap

        templates = {
            'calculation': '''
                You are a mathematical assistant. Solve this problem step by step:

                {question}

                IMPORTANT: Use the calculator_tool for ALL mathematical calculations, even simple ones.
                Examples:
                - For "25 * 34", use: calculator_tool("25 * 34")
                - For "sqrt(16)", use: calculator_tool("sqrt(16)")
                - For "2 + 2", use: calculator_tool("2 + 2")

                Always show your work and use the tools provided.
                ''',
            'research': '''
                You are a research assistant. Provide comprehensive information about:

                {question}

                IMPORTANT: Use the appropriate search tools to gather information:
                - wikipedia_search_tool("your search query") for general knowledge
                - web_search_tool("your search query") for current information
                - arxiv_search_tool("your search query") for academic papers

                Always cite your sources and provide detailed explanations.
                ''',
            'academic': '''
                You are an academic research assistant. Find scholarly information about:

                {question}

                IMPORTANT: Use research tools to find information:
                - arxiv_search_tool("your search query") for academic papers
                - wikipedia_search_tool("your search query") for background information

                Provide citations and summarize key findings.
                ''',
            'general': '''
                You are a helpful assistant. Answer this question comprehensively:

                {question}

                IMPORTANT: Use the appropriate tools as needed:
                - calculator_tool("expression") for mathematical calculations
                - wikipedia_search_tool("query") for general information
                - web_search_tool("query") for current information
                - arxiv_search_tool("query") for academic research

                Always use tools when they can help provide better answers.
                ''',
        }
        template = templates.get(approach, templates['general'])
        # Dedent before substituting, so a multi-line question cannot change
        # the common indentation that dedent strips.
        return textwrap.dedent(template).format(question=question)

    def __call__(self, question: str) -> str:
        """Process a question and return an answer.

        Runs the agent on a prompt chosen by _determine_approach. On any
        failure it falls back to a direct LLM call. Failures are reported as
        strings starting with "Error: Unable to process question." rather
        than raised.
        """
        try:
            print(f"Processing question: {question[:100]}...")

            if self.agent is None:
                print("Agent not available, using direct LLM response")
                try:
                    response = self.llm.invoke([HumanMessage(content=question)])
                    return response.content
                except Exception as llm_error:
                    return f"Error: Unable to process question. {str(llm_error)}"

            approach = self._determine_approach(question)
            print(f"Selected approach: {approach}")

            result = self.agent.run(self._build_prompt(question, approach))

            print(f"Generated answer: {str(result)[:200]}...")
            return str(result)

        except Exception as e:
            print(f"Error processing question: {str(e)}")
            try:
                fallback_result = self.llm.invoke([HumanMessage(content=question)])
                return fallback_result.content
            except Exception as fallback_error:
                # Report both failures. The agent error explains why the
                # fallback was needed; the fallback error explains why it
                # did not help either.
                return (
                    f"Error: Unable to process question. {str(e)} "
                    f"(fallback LLM call also failed: {str(fallback_error)})"
                )

    def reset_memory(self):
        """Reset the conversation memory."""
        self.memory.clear()
| |
|
| | |
def test_langchain_agent():
    """Smoke-test LangChainAgent on the Groq backend with a few sample questions.

    Prints each question and its answer. Memory is cleared after every
    question so that answers do not carry context into the next one.
    """
    print("Testing LangChain Agent with Groq Llama...")
    runner = LangChainAgent(provider="groq")

    thin_rule, thick_rule = "-" * 50, "=" * 80
    samples = (
        "What is 25 * 34 + 100?",
        "Who was Albert Einstein and what were his major contributions?",
        "Search for recent developments in artificial intelligence",
        "What is quantum computing?",
    )

    for prompt in samples:
        print(f"\nQuestion: {prompt}")
        print(thin_rule)
        reply = runner(prompt)
        print(f"Answer: {reply}")
        print(thick_rule)
        runner.reset_memory()
| |
|
# Run the sample questions only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    test_langchain_agent()
| |
|