# Final_Assignment_Template / crewai_agent.py
# (Hugging Face page header converted to comments — uploader: Aditya0619,
#  commit: 04968b6 verified, message: "Update crewai_agent.py")
import os
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
from langchain.agents import AgentType, initialize_agent, Tool
from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.tools import tool
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Load provider/tool API keys once at import time. Each constant is None
# when its environment variable is unset; consumers check at use time
# (raising ValueError or returning an error string) rather than failing here.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')  # Google Gemini models
HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')  # Hugging Face endpoints
GROQ_API_KEY = os.getenv('GROQ_API_KEY')  # Groq-hosted Llama models
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')  # Tavily web search tool
@tool
def calculator_tool(expression: str) -> str:
    """Perform mathematical calculations and evaluate expressions
    Args:
        expression: A mathematical expression to evaluate (e.g., "2+2", "25*34", "sqrt(16)", "sin(0.5)")
    Returns:
        Result of the mathematical expression
    """
    import math
    import re

    expression = expression.strip()
    # Python uses ** for exponentiation; accept the common '^' spelling.
    expression = expression.replace('^', '**')

    # Names the expression is allowed to reference. Exposing math functions
    # through the eval namespace (instead of text-replacing "sqrt" ->
    # "math.sqrt" etc.) fixes the corruption the old chained str.replace()
    # calls caused: "log10" became "math.math.log10", and every "e" —
    # including inside other names — was rewritten to "math.e".
    allowed_names = {
        "sqrt": math.sqrt,
        "sin": math.sin,
        "cos": math.cos,
        "tan": math.tan,
        "asin": math.asin,
        "acos": math.acos,
        "atan": math.atan,
        "log": math.log,      # natural log; log(x, base) also works
        "log10": math.log10,
        "exp": math.exp,
        "pi": math.pi,
        "e": math.e,
        "abs": abs,
        "round": round,
        "min": min,
        "max": max,
        "pow": pow,
    }

    # Validate before eval(): only digits, operators, parentheses, commas,
    # spaces and whitelisted identifiers may appear.
    if not re.fullmatch(r"[0-9A-Za-z_+\-*/%.,() ]*", expression):
        return "Error: Invalid characters in expression. Use only numbers and basic math operations."
    for name in re.findall(r"[A-Za-z_][A-Za-z_0-9]*", expression):
        if name not in allowed_names:
            return "Error: Invalid characters in expression. Use only numbers and basic math operations."

    try:
        # eval() is confined to the whitelist; builtins are disabled.
        result = eval(expression, {"__builtins__": {}}, allowed_names)
        return str(result)
    except ZeroDivisionError:
        return "Error: Cannot divide by zero"
    except Exception as e:
        return f"Error: {str(e)}"
@tool
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for information on any topic
    Args:
        query: The search query for Wikipedia
    Returns:
        Formatted Wikipedia search results
    """
    try:
        # Fetch at most two matching articles.
        documents = WikipediaLoader(query=query, load_max_docs=2).load()
        # Render each article as a source header plus a truncated excerpt.
        sections = []
        for document in documents:
            source = document.metadata["source"]
            page = document.metadata.get("page", "")
            excerpt = document.page_content[:2000]
            sections.append(f'Source: {source}\nPage: {page}\n\nContent:\n{excerpt}...')
        return "\n\n---\n\n".join(sections)
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
@tool
def web_search_tool(query: str) -> str:
    """Search the web for current information using Tavily
    Args:
        query: The search query for web search
    Returns:
        Formatted web search results
    """
    try:
        # Tavily requires an API key; fail with a readable message if absent.
        if not TAVILY_API_KEY:
            return "Error: TAVILY_API_KEY not found in environment variables"
        hits = TavilySearchResults(max_results=3, api_key=TAVILY_API_KEY).invoke(query)
        # Render each hit as a source URL plus its content snippet.
        rendered = []
        for hit in hits:
            url = hit.get("url", "")
            content = hit.get("content", "")
            rendered.append(f'Source: {url}\n\nContent:\n{content}')
        return "\n\n---\n\n".join(rendered)
    except Exception as e:
        return f"Error searching web: {str(e)}"
@tool
def arxiv_search_tool(query: str) -> str:
    """Search ArXiv for academic papers and research
    Args:
        query: The search query for ArXiv
    Returns:
        Formatted ArXiv search results
    """
    try:
        # Pull up to three matching papers from ArXiv.
        papers = ArxivLoader(query=query, load_max_docs=3).load()
        # One section per paper: source, title, then a truncated abstract/body.
        sections = []
        for paper in papers:
            header = f'Source: {paper.metadata["source"]}\nTitle: {paper.metadata.get("Title", "")}'
            sections.append(f'{header}\n\nContent:\n{paper.page_content[:1500]}...')
        return "\n\n---\n\n".join(sections)
    except Exception as e:
        return f"Error searching ArXiv: {str(e)}"
class LangChainAgent:
    """Multi-purpose LangChain agent with various capabilities.

    Builds a ReAct-style agent around a configurable chat LLM (Groq,
    Google Gemini or Hugging Face) with calculator, Wikipedia, web and
    ArXiv search tools, plus summarizing conversation memory. When agent
    construction or execution fails, falls back to calling the bare LLM.
    """

    def __init__(self, provider: str = "groq"):
        """Initialize the LangChain agent with specified LLM provider.

        Args:
            provider: LLM backend: 'groq' (default), 'google' or 'huggingface'.

        Raises:
            ValueError: If the provider is unknown or its API key is not set.
        """
        self.provider = provider
        self.llm = self._get_llm(provider)
        self.tools = self._initialize_tools()
        # ConversationSummaryBufferMemory keeps recent turns verbatim and
        # summarizes older ones, preventing unbounded prompt growth.
        self.memory = ConversationSummaryBufferMemory(
            llm=self.llm,
            memory_key="chat_history",
            return_messages=True,
            max_token_limit=2000,  # Limit memory to prevent token overflow
            moving_summary_buffer="The human and AI are having a conversation about various topics."
        )
        self.agent = self._create_agent()

    def _get_llm(self, provider: str):
        """Return a chat-model instance for the requested provider.

        Raises:
            ValueError: If the provider name is invalid or the matching
                API key environment variable is missing.
        """
        if provider == "groq":
            if not GROQ_API_KEY:
                raise ValueError("GROQ_API_KEY not found in environment variables")
            return ChatGroq(
                model="llama-3.3-70b-versatile",  # Latest Llama model available on Groq
                temperature=0.1,
                max_tokens=8192,  # Increased token limit for better responses
                api_key=GROQ_API_KEY,
                streaming=False
            )
        elif provider == "google":
            if not GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY not found in environment variables")
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,
                max_tokens=2048,
                google_api_key=GOOGLE_API_KEY
            )
        elif provider == "huggingface":
            if not HUGGINGFACE_API_TOKEN:
                raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id="microsoft/DialoGPT-medium",
                    temperature=0,
                    max_length=2048,
                    huggingfacehub_api_token=HUGGINGFACE_API_TOKEN
                ),
            )
        else:
            raise ValueError("Invalid provider. Choose 'groq', 'google' or 'huggingface'.")

    def _initialize_tools(self) -> "List[Tool]":
        """Initialize all available tools.

        The return annotation is quoted so the class can be defined even
        when the LangChain names are not importable at definition time.
        """
        return [
            calculator_tool,
            wikipedia_search_tool,
            web_search_tool,
            arxiv_search_tool,
        ]

    def _create_agent(self):
        """Create the LangChain agent with tools.

        Returns:
            The initialized agent executor, or None when construction fails
            (callers then fall back to direct LLM calls).
        """
        try:
            return initialize_agent(
                tools=self.tools,
                llm=self.llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                memory=self.memory,
                verbose=True,
                handle_parsing_errors=True,
                max_iterations=4,
                early_stopping_method="generate",
                return_intermediate_steps=False
            )
        except Exception as e:
            print(f"Error creating agent: {e}")
            # Signal failure; __call__ falls back to the bare LLM.
            return None

    def _determine_approach(self, question: str) -> str:
        """Determine the best approach for answering the question.

        Returns one of 'calculation', 'research', 'academic' or 'general',
        decided by keyword heuristics checked in that priority order.
        """
        import re

        question_lower = question.lower()
        # Mathematical wording, or an arithmetic operator adjacent to a
        # digit. Previously bare '+', '-', '*', '/', '%' substrings were
        # enough, so any hyphenated word (e.g. "e-mail") misrouted the
        # question to the calculator.
        math_keywords = ['calculate', 'compute', 'add', 'subtract', 'multiply', 'divide', 'math', 'equation']
        if any(keyword in question_lower for keyword in math_keywords):
            return 'calculation'
        if re.search(r'\d\s*[+\-*/%]', question_lower):
            return 'calculation'
        # General research / lookup questions.
        research_keywords = ['search', 'find', 'research', 'information', 'what is', 'who is', 'when', 'where', 'how', 'why']
        if any(keyword in question_lower for keyword in research_keywords):
            return 'research'
        # Academic / scientific questions.
        academic_keywords = ['paper', 'study', 'research', 'academic', 'scientific', 'arxiv', 'journal']
        if any(keyword in question_lower for keyword in academic_keywords):
            return 'academic'
        return 'general'

    def _build_prompt(self, question: str, approach: str) -> str:
        """Wrap the raw question in approach-specific tool-usage instructions."""
        if approach == 'calculation':
            return f"""
You are a mathematical assistant. Solve this problem step by step:
{question}
IMPORTANT: Use the calculator_tool for ALL mathematical calculations, even simple ones.
Examples:
- For "25 * 34", use: calculator_tool("25 * 34")
- For "sqrt(16)", use: calculator_tool("sqrt(16)")
- For "2 + 2", use: calculator_tool("2 + 2")
Always show your work and use the tools provided.
"""
        elif approach == 'research':
            return f"""
You are a research assistant. Provide comprehensive information about:
{question}
IMPORTANT: Use the appropriate search tools to gather information:
- wikipedia_search_tool("your search query") for general knowledge
- web_search_tool("your search query") for current information
- arxiv_search_tool("your search query") for academic papers
Always cite your sources and provide detailed explanations.
"""
        elif approach == 'academic':
            return f"""
You are an academic research assistant. Find scholarly information about:
{question}
IMPORTANT: Use research tools to find information:
- arxiv_search_tool("your search query") for academic papers
- wikipedia_search_tool("your search query") for background information
Provide citations and summarize key findings.
"""
        else:
            return f"""
You are a helpful assistant. Answer this question comprehensively:
{question}
IMPORTANT: Use the appropriate tools as needed:
- calculator_tool("expression") for mathematical calculations
- wikipedia_search_tool("query") for general information
- web_search_tool("query") for current information
- arxiv_search_tool("query") for academic research
Always use tools when they can help provide better answers.
"""

    def __call__(self, question: str) -> str:
        """Process a question and return an answer.

        Routes the question to an approach-specific prompt, runs it through
        the tool-using agent, and falls back to a direct LLM call when the
        agent is unavailable or raises.
        """
        try:
            print(f"Processing question: {question[:100]}...")
            # If agent initialization failed, use direct LLM
            if self.agent is None:
                print("Agent not available, using direct LLM response")
                try:
                    response = self.llm.invoke([HumanMessage(content=question)])
                    return response.content
                except Exception as llm_error:
                    return f"Error: Unable to process question. {str(llm_error)}"
            # Determine the best approach for this question
            approach = self._determine_approach(question)
            print(f"Selected approach: {approach}")
            enhanced_question = self._build_prompt(question, approach)
            # Use the agent to process the question
            result = self.agent.run(enhanced_question)
            print(f"Generated answer: {str(result)[:200]}...")
            return str(result)
        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)
            # Fall back to a plain LLM response without tools.
            try:
                fallback_result = self.llm.invoke([HumanMessage(content=question)])
                return fallback_result.content
            except Exception as fallback_error:
                # Surface the fallback failure instead of silently dropping it.
                print(f"Fallback LLM call also failed: {fallback_error}")
                return f"Error: Unable to process question. {str(e)}"

    def reset_memory(self):
        """Reset the conversation memory."""
        self.memory.clear()
# Test function
def test_langchain_agent():
    """Test the LangChain agent with sample questions."""
    print("Testing LangChain Agent with Groq Llama...")
    agent = LangChainAgent(provider="groq")
    sample_questions = (
        "What is 25 * 34 + 100?",
        "Who was Albert Einstein and what were his major contributions?",
        "Search for recent developments in artificial intelligence",
        "What is quantum computing?",
    )
    for sample in sample_questions:
        print(f"\nQuestion: {sample}")
        print("-" * 50)
        print(f"Answer: {agent(sample)}")
        print("=" * 80)
        agent.reset_memory()  # Reset memory between questions for testing

if __name__ == "__main__":
    test_langchain_agent()