Final_Assignment_Template / crewai_agent.py
Aditya0619's picture
Update crewai_agent.py
7a7bd49 verified
raw
history blame
11.2 kB
import os
import re
from typing import List, Dict, Any, Optional

from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
from langchain.agents import AgentType, initialize_agent, Tool
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.tools import tool
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# API credentials read from the process environment (which load_dotenv() above
# may have populated from a .env file); each is None when the variable is unset.
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
HUGGINGFACE_API_TOKEN = os.environ.get('HUGGINGFACE_API_TOKEN')
TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY')
@tool
def calculator_tool(operation: str, a: float, b: float) -> str:
    """Perform basic mathematical operations: add, subtract, multiply, divide, modulus

    Args:
        operation: The operation to perform (add, subtract, multiply, divide, modulus).
            Case-insensitive; the symbols +, -, *, /, % are accepted as aliases.
        a: First number
        b: Second number

    Returns:
        Result of the mathematical operation, or a message starting with "Error:"
    """
    # The operation name comes from model-generated text, so tolerate case and
    # whitespace variants and plain operator symbols.
    symbol_aliases = {
        "+": "add",
        "-": "subtract",
        "*": "multiply",
        "/": "divide",
        "%": "modulus",
    }
    try:
        op = operation.strip().lower()
        op = symbol_aliases.get(op, op)
        if op == "add":
            return str(a + b)
        if op == "subtract":
            return str(a - b)
        if op == "multiply":
            return str(a * b)
        if op == "divide":
            if b == 0:
                return "Error: Cannot divide by zero"
            return str(a / b)
        if op == "modulus":
            # Checked explicitly, like divide. Otherwise the caller would get the
            # raw ZeroDivisionError text, which says nothing about the input.
            if b == 0:
                return "Error: Cannot take modulus by zero"
            return str(a % b)
        return "Error: Unsupported operation. Use: add, subtract, multiply, divide, modulus"
    except Exception as e:
        return f"Error: {str(e)}"
@tool
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for information on any topic

    Args:
        query: The search query for Wikipedia

    Returns:
        Formatted Wikipedia search results (up to 2 pages, each truncated to
        2000 characters), or a message explaining why none are available
    """
    max_chars = 2000
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            # Give the agent an explicit observation instead of an empty string.
            return f"No Wikipedia results found for: {query}"
        formatted = []
        for doc in search_docs:
            content = doc.page_content
            # Only mark content as truncated when it actually was.
            if len(content) > max_chars:
                content = content[:max_chars] + "..."
            # .get() so a missing metadata key degrades to "" rather than raising;
            # the page name lives under "title" (there is no "page" key).
            formatted.append(
                f'Source: {doc.metadata.get("source", "")}\n'
                f'Title: {doc.metadata.get("title", "")}\n\n'
                f'Content:\n{content}'
            )
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
@tool
def web_search_tool(query: str) -> str:
    """Search the web for current information using Tavily

    Args:
        query: The search query for web search

    Returns:
        Formatted web search results (up to 3), or a message explaining why
        none are available
    """
    if not TAVILY_API_KEY:
        return "Error: TAVILY_API_KEY not found in environment variables"
    try:
        # TavilySearchResults takes the key as ``tavily_api_key``; ``api_key``
        # is not one of its parameters.
        search_results = TavilySearchResults(
            max_results=3, tavily_api_key=TAVILY_API_KEY
        ).invoke(query)
        # NOTE(review): the tool appears to report request failures as a plain
        # string rather than raising. Surface that string as an error instead of
        # iterating over its characters.
        if isinstance(search_results, str):
            return f"Error searching web: {search_results}"
        if not search_results:
            return f"No web results found for: {query}"
        return "\n\n---\n\n".join(
            f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
            for result in search_results
        )
    except Exception as e:
        return f"Error searching web: {str(e)}"
@tool
def arxiv_search_tool(query: str) -> str:
    """Search ArXiv for academic papers and research

    Args:
        query: The search query for ArXiv

    Returns:
        Formatted ArXiv search results (up to 3 papers, each truncated to
        1500 characters), or a message explaining why none are available
    """
    max_chars = 1500
    try:
        # load_all_available_meta=True adds extra metadata such as "entry_id"
        # (the paper URL). By default ArxivLoader documents carry only
        # Published/Title/Authors/Summary; there is no "source" key, so indexing
        # doc.metadata["source"] raised KeyError on every call.
        search_docs = ArxivLoader(
            query=query, load_max_docs=3, load_all_available_meta=True
        ).load()
        if not search_docs:
            return f"No ArXiv results found for: {query}"
        formatted = []
        for doc in search_docs:
            meta = doc.metadata
            content = doc.page_content
            # Only mark content as truncated when it actually was.
            if len(content) > max_chars:
                content = content[:max_chars] + "..."
            formatted.append(
                f'Source: {meta.get("entry_id", "")}\n'
                f'Title: {meta.get("Title", "")}\n'
                f'Authors: {meta.get("Authors", "")}\n'
                f'Published: {meta.get("Published", "")}\n\n'
                f'Content:\n{content}'
            )
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        return f"Error searching ArXiv: {str(e)}"
class LangChainAgent:
    """Multi-purpose LangChain agent with calculator and search tools.

    Wraps a chat model (Google Gemini or a Hugging Face endpoint) in a
    ReAct-style agent that can call calculator_tool, wikipedia_search_tool,
    web_search_tool and arxiv_search_tool. Each question is classified with a
    keyword heuristic and wrapped in approach-specific instructions. When the
    agent cannot be built, or raises while answering, the bare LLM is asked
    instead.
    """

    # Arithmetic cues. Words are matched whole, so "address" is not "add" and
    # "aftermath" is not "math". Operators count only between digits, so
    # hyphenated words and URLs no longer look like maths; "15%" counts too.
    _MATH_PATTERN = re.compile(
        r"""
        \b(?:
            calculat(?:e|es|ed|ing|ion|ions)
          | comput(?:e|es|ed|ing|ation)
          | add(?:s|ed|ing|ition)?
          | subtract(?:s|ed|ing|ion)?
          | multipl(?:y|ies|ied|ying|ication)
          | divid(?:e|es|ed|ing)
          | division
          | math(?:s|ematics|ematical)?
          | equations?
        )\b
        | \d\s*[-+*/%^]\s*\d
        | \d\s*%
        """,
        re.IGNORECASE | re.VERBOSE,
    )

    # Vocabulary that specifically asks for scholarly sources.
    _ACADEMIC_PATTERN = re.compile(
        r"\b(?:papers?|stud(?:y|ies)|academic|scientific|arxiv|journals?)\b",
        re.IGNORECASE,
    )

    # Generic information-seeking cues, matched as whole words ("show" is not "how").
    _RESEARCH_PATTERN = re.compile(
        r"\b(?:search(?:es|ed|ing)?|find(?:s|ing)?|research(?:es|ed|ing)?"
        r"|information|what\s+is|who\s+is|when|where|how|why)\b",
        re.IGNORECASE,
    )

    def __init__(self, provider: str = "google"):
        """Build the LLM, tools, conversation memory and agent.

        Args:
            provider: "google" (uses GOOGLE_API_KEY) or "huggingface"
                (uses HUGGINGFACE_API_TOKEN).

        Raises:
            ValueError: if the provider is unknown or its credential is unset.
        """
        self.provider = provider
        self.llm = self._get_llm(provider)
        self.tools = self._initialize_tools()
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.agent = self._create_agent()

    def _get_llm(self, provider: str):
        """Return the chat model for *provider*.

        Raises:
            ValueError: if *provider* is not "google" or "huggingface", or the
                matching credential is missing from the environment.
        """
        if provider == "google":
            if not GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY not found in environment variables")
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,
                max_tokens=2048,
                google_api_key=GOOGLE_API_KEY,
            )
        if provider == "huggingface":
            if not HUGGINGFACE_API_TOKEN:
                raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
            # NOTE(review): DialoGPT-medium is a conversational model. Confirm the
            # HF endpoint serves it for text generation with a chat template, and
            # that temperature=0 is accepted.
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id="microsoft/DialoGPT-medium",
                    temperature=0,
                    # HuggingFaceEndpoint's length limit is max_new_tokens;
                    # "max_length" is not one of its declared fields.
                    max_new_tokens=2048,
                    huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
                ),
            )
        raise ValueError("Invalid provider. Choose 'google' or 'huggingface'.")

    def _initialize_tools(self) -> List[Any]:
        """Return the LangChain tool objects (created by @tool) the agent may call."""
        return [
            calculator_tool,
            wikipedia_search_tool,
            web_search_tool,
            arxiv_search_tool,
        ]

    def _create_agent(self):
        """Create the tool-using agent executor, or return None if that fails.

        Uses the structured-chat ReAct agent because calculator_tool takes three
        arguments. ZERO_SHOT_REACT_DESCRIPTION only accepts single-input tools
        and raises ValueError at construction, which left self.agent as None
        and disabled every tool.
        """
        try:
            return initialize_agent(
                tools=self.tools,
                llm=self.llm,
                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
                memory=self.memory,
                verbose=True,
                handle_parsing_errors=True,
                max_iterations=3,
                early_stopping_method="generate",
            )
        except Exception as e:
            print(f"Error creating agent: {e}")
            # __call__ falls back to the bare LLM when no agent is available.
            return None

    def _determine_approach(self, question: str) -> str:
        """Classify *question* as 'calculation', 'academic', 'research' or 'general'.

        Checks run from most to least specific. Academic cues are tested before
        generic research cues: words like "how" or "find" appear in most academic
        questions, which previously made the academic branch unreachable.
        """
        if self._MATH_PATTERN.search(question):
            return 'calculation'
        if self._ACADEMIC_PATTERN.search(question):
            return 'academic'
        if self._RESEARCH_PATTERN.search(question):
            return 'research'
        return 'general'

    def _build_prompt(self, question: str, approach: str) -> str:
        """Wrap *question* in instructions suited to *approach*."""
        if approach == 'calculation':
            return (
                "Solve this mathematical problem step by step:\n"
                f"{question}\n"
                "Use the calculator tool if needed for complex calculations. "
                "Show your work clearly."
            )
        if approach == 'research':
            return (
                "Research and provide comprehensive information about:\n"
                f"{question}\n"
                "Use Wikipedia search and web search tools to gather current and "
                "accurate information.\n"
                "Cite your sources and provide detailed explanations."
            )
        if approach == 'academic':
            return (
                "Find academic and scientific information about:\n"
                f"{question}\n"
                "Use ArXiv search and other research tools to find relevant "
                "academic papers and studies.\n"
                "Provide citations and summarize key findings."
            )
        return (
            "Provide a comprehensive answer to:\n"
            f"{question}\n"
            "Use appropriate tools as needed (calculator, search tools) to provide "
            "accurate information."
        )

    def __call__(self, question: str) -> str:
        """Answer *question*, using the tool agent when it is available.

        Returns:
            The agent's final answer as a string. If there is no agent, or the
            agent raises, the bare LLM's reply is returned instead. If the LLM
            also fails, an "Error: ..." message is returned rather than raising.
        """
        try:
            print(f"Processing question: {question[:100]}...")
            if self.agent is None:
                print("Agent not available, using direct LLM response")
                try:
                    response = self.llm.invoke([HumanMessage(content=question)])
                    return response.content
                except Exception as llm_error:
                    return f"Error: Unable to process question. {str(llm_error)}"

            approach = self._determine_approach(question)
            print(f"Selected approach: {approach}")
            enhanced_question = self._build_prompt(question, approach)

            # invoke() replaces the deprecated Chain.run(); the executor's
            # answer is under its "output" key.
            result = self.agent.invoke({"input": enhanced_question})["output"]
            print(f"Generated answer: {str(result)[:200]}...")
            return str(result)
        except Exception as e:
            print(f"Error processing question: {str(e)}")
            try:
                # Best effort: answer without tools.
                fallback_result = self.llm.invoke([HumanMessage(content=question)])
                return fallback_result.content
            except Exception as fallback_error:
                # Report both failures so neither cause is lost.
                return (
                    f"Error: Unable to process question. {str(e)} "
                    f"(fallback LLM also failed: {fallback_error})"
                )

    def reset_memory(self):
        """Clear the conversation history kept in self.memory."""
        self.memory.clear()
def test_langchain_agent():
    """Run LangChainAgent on a few sample questions and print each answer.

    Uses the "google" provider, so GOOGLE_API_KEY must be set (the agent's
    constructor raises ValueError otherwise). Memory is cleared after each
    question so the answers do not influence one another.
    """
    sample_questions = (
        "What is 25 * 34?",
        "Who was Albert Einstein?",
        "Search for recent developments in artificial intelligence",
        "What is the theory of relativity?",
    )
    agent = LangChainAgent(provider="google")
    separator = "-" * 50
    for sample in sample_questions:
        print(f"\nQuestion: {sample}")
        reply = agent(sample)
        print(f"Answer: {reply}")
        print(separator)
        # Start every sample question with an empty conversation history.
        agent.reset_memory()


if __name__ == "__main__":
    test_langchain_agent()