Aditya0619 commited on
Commit
f0250b1
·
verified ·
1 Parent(s): f2387db

Update crewai_agent.py

Browse files
Files changed (1) hide show
  1. crewai_agent.py +178 -172
crewai_agent.py CHANGED
@@ -1,187 +1,182 @@
1
  import os
2
- from typing import List, Dict, Any
3
  from dotenv import load_dotenv
4
 
5
- from crewai import Agent, Task, Crew, Process
6
- from crewai.tools import BaseTool
 
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
- from pydantic import BaseModel, Field
 
 
12
 
13
  # Load environment variables
14
  load_dotenv()
15
 
16
- class CalculatorTool(BaseTool):
17
- """Mathematical calculator tool for basic arithmetic operations."""
18
- name: str = "calculator"
19
- description: str = "Perform basic mathematical operations: add, subtract, multiply, divide, modulus"
20
 
21
- def _run(self, operation: str, a: float, b: float) -> float:
22
- """Execute mathematical operations."""
23
- try:
24
- if operation == "add":
25
- return a + b
26
- elif operation == "subtract":
27
- return a - b
28
- elif operation == "multiply":
29
- return a * b
30
- elif operation == "divide":
31
- if b == 0:
32
- return "Error: Cannot divide by zero"
33
- return a / b
34
- elif operation == "modulus":
35
- return a % b
36
- else:
37
- return "Error: Unsupported operation"
38
- except Exception as e:
39
- return f"Error: {str(e)}"
 
 
 
 
 
 
40
 
41
- class WikipediaSearchTool(BaseTool):
42
- """Wikipedia search tool for research."""
43
- name: str = "wikipedia_search"
44
- description: str = "Search Wikipedia for information on any topic"
45
 
46
- def _run(self, query: str) -> str:
47
- """Search Wikipedia and return formatted results."""
48
- try:
49
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
50
- formatted_results = "\n\n---\n\n".join([
51
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
52
- for doc in search_docs
53
- ])
54
- return formatted_results
55
- except Exception as e:
56
- return f"Error searching Wikipedia: {str(e)}"
 
 
 
 
57
 
58
- class WebSearchTool(BaseTool):
59
- """Web search tool using Tavily."""
60
- name: str = "web_search"
61
- description: str = "Search the web for current information using Tavily"
62
 
63
- def _run(self, query: str) -> str:
64
- """Search the web and return formatted results."""
65
- try:
66
- search_results = TavilySearchResults(max_results=3).invoke(query)
67
- formatted_results = "\n\n---\n\n".join([
68
- f'<Document source="{result.get("url", "")}">\n{result.get("content", "")}\n</Document>'
69
- for result in search_results
70
- ])
71
- return formatted_results
72
- except Exception as e:
73
- return f"Error searching web: {str(e)}"
 
 
 
 
74
 
75
- class ArxivSearchTool(BaseTool):
76
- """ArXiv search tool for academic papers."""
77
- name: str = "arxiv_search"
78
- description: str = "Search ArXiv for academic papers and research"
79
 
80
- def _run(self, query: str) -> str:
81
- """Search ArXiv and return formatted results."""
82
- try:
83
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
84
- formatted_results = "\n\n---\n\n".join([
85
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
86
- for doc in search_docs
87
- ])
88
- return formatted_results
89
- except Exception as e:
90
- return f"Error searching ArXiv: {str(e)}"
 
 
 
 
91
 
92
- class CrewAIAgent:
93
- """Multi-purpose CrewAI agent with various capabilities."""
94
 
95
  def __init__(self, provider: str = "google"):
96
- """Initialize the CrewAI agent with specified LLM provider."""
97
  self.provider = provider
98
  self.llm = self._get_llm(provider)
99
  self.tools = self._initialize_tools()
100
- self.agents = self._create_agents()
 
101
 
102
  def _get_llm(self, provider: str):
103
  """Get the specified LLM."""
104
  if provider == "google":
105
- return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
 
 
 
 
106
  elif provider == "huggingface":
107
  return ChatHuggingFace(
108
  llm=HuggingFaceEndpoint(
109
  repo_id="microsoft/DialoGPT-medium",
110
  temperature=0,
 
111
  ),
112
  )
113
  else:
114
  raise ValueError("Invalid provider. Choose 'google' or 'huggingface'.")
115
 
116
- def _initialize_tools(self) -> List[BaseTool]:
117
  """Initialize all available tools."""
118
  return [
119
- CalculatorTool(),
120
- WikipediaSearchTool(),
121
- WebSearchTool(),
122
- ArxivSearchTool(),
123
  ]
124
 
125
- def _create_agents(self) -> Dict[str, Agent]:
126
- """Create specialized agents for different tasks."""
127
-
128
- # Research Agent
129
- research_agent = Agent(
130
- role='Research Specialist',
131
- goal='Gather comprehensive and accurate information from multiple sources',
132
- backstory="""You are an expert researcher with access to Wikipedia, ArXiv, and web search tools.
133
- You excel at finding relevant, current, and reliable information on any topic.""",
134
- tools=[tool for tool in self.tools if 'search' in tool.name],
135
- llm=self.llm,
136
- verbose=True,
137
- allow_delegation=False
138
- )
139
-
140
- # Calculation Agent
141
- calculation_agent = Agent(
142
- role='Mathematical Analyst',
143
- goal='Perform accurate mathematical calculations and analysis',
144
- backstory="""You are a mathematical expert capable of performing various calculations
145
- and explaining mathematical concepts clearly.""",
146
- tools=[tool for tool in self.tools if 'calculator' in tool.name],
147
- llm=self.llm,
148
- verbose=True,
149
- allow_delegation=False
150
- )
151
-
152
- # General Assistant Agent
153
- general_agent = Agent(
154
- role='General Assistant',
155
- goal='Provide comprehensive answers by coordinating with specialized agents',
156
- backstory="""You are a versatile AI assistant that can handle various types of questions
157
- by leveraging specialized tools and knowledge.""",
158
  tools=self.tools,
159
  llm=self.llm,
 
 
160
  verbose=True,
161
- allow_delegation=True
 
 
162
  )
163
-
164
- return {
165
- 'research': research_agent,
166
- 'calculation': calculation_agent,
167
- 'general': general_agent
168
- }
169
 
170
- def _determine_agent_type(self, question: str) -> str:
171
- """Determine which agent is best suited for the question."""
172
  question_lower = question.lower()
173
 
174
  # Check for mathematical operations
175
- math_keywords = ['calculate', 'compute', 'add', 'subtract', 'multiply', 'divide', 'math', 'equation']
176
  if any(keyword in question_lower for keyword in math_keywords):
177
  return 'calculation'
178
 
179
  # Check for research-related queries
180
- research_keywords = ['search', 'find', 'research', 'information', 'what is', 'who is', 'when', 'where', 'how']
181
  if any(keyword in question_lower for keyword in research_keywords):
182
  return 'research'
183
 
184
- # Default to general agent
 
 
 
 
185
  return 'general'
186
 
187
  def __call__(self, question: str) -> str:
@@ -189,61 +184,71 @@ class CrewAIAgent:
189
  try:
190
  print(f"Processing question: {question[:100]}...")
191
 
192
- # Determine the best agent for this question
193
- agent_type = self._determine_agent_type(question)
194
- selected_agent = self.agents[agent_type]
195
-
196
- print(f"Selected agent: {agent_type}")
197
 
198
- # Create a task for the selected agent
199
- task = Task(
200
- description=f"""
201
- Answer the following question comprehensively and accurately:
202
 
203
- Question: {question}
204
 
205
- Guidelines:
206
- - Use appropriate tools when needed
207
- - Provide detailed and helpful responses
208
- - Cite sources when using external information
209
- - Show calculations step by step for mathematical problems
210
- - Be clear and concise in your explanations
211
- """,
212
- agent=selected_agent,
213
- expected_output="A comprehensive and accurate answer to the user's question"
214
- )
215
-
216
- # Create and execute the crew
217
- crew = Crew(
218
- agents=[selected_agent],
219
- tasks=[task],
220
- process=Process.sequential,
221
- verbose=True
222
- )
223
-
224
- # Execute the task
225
- result = crew.kickoff()
226
-
227
- # Extract the answer from the result
228
- if hasattr(result, 'raw'):
229
- answer = result.raw
230
- elif isinstance(result, str):
231
- answer = result
232
  else:
233
- answer = str(result)
 
 
 
 
 
 
234
 
235
- print(f"Generated answer: {answer[:200]}...")
236
- return answer
 
 
 
237
 
238
  except Exception as e:
239
  error_msg = f"Error processing question: {str(e)}"
240
  print(error_msg)
241
- return error_msg
 
 
 
 
 
 
 
 
 
 
242
 
243
  # Test function
244
- def test_crewai_agent():
245
- """Test the CrewAI agent with sample questions."""
246
- agent = CrewAIAgent(provider="google")
247
 
248
  test_questions = [
249
  "What is 25 * 34?",
@@ -257,6 +262,7 @@ def test_crewai_agent():
257
  answer = agent(question)
258
  print(f"Answer: {answer}")
259
  print("-" * 50)
 
260
 
261
  if __name__ == "__main__":
262
- test_crewai_agent()
 
1
  import os
2
+ from typing import List, Dict, Any, Optional
3
  from dotenv import load_dotenv
4
 
5
+ from langchain.agents import AgentType, initialize_agent, Tool
6
+ from langchain.memory import ConversationBufferMemory
7
+ from langchain.schema import BaseMessage, HumanMessage, AIMessage
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
+ from langchain.tools import tool
13
+ from langchain.prompts import PromptTemplate
14
+ from langchain.chains import LLMChain
15
 
16
  # Load environment variables
17
  load_dotenv()
18
 
19
+ @tool
20
+ def calculator_tool(operation: str, a: float, b: float) -> str:
21
+ """Perform basic mathematical operations: add, subtract, multiply, divide, modulus
 
22
 
23
+ Args:
24
+ operation: The operation to perform (add, subtract, multiply, divide, modulus)
25
+ a: First number
26
+ b: Second number
27
+
28
+ Returns:
29
+ Result of the mathematical operation
30
+ """
31
+ try:
32
+ if operation == "add":
33
+ return str(a + b)
34
+ elif operation == "subtract":
35
+ return str(a - b)
36
+ elif operation == "multiply":
37
+ return str(a * b)
38
+ elif operation == "divide":
39
+ if b == 0:
40
+ return "Error: Cannot divide by zero"
41
+ return str(a / b)
42
+ elif operation == "modulus":
43
+ return str(a % b)
44
+ else:
45
+ return "Error: Unsupported operation. Use: add, subtract, multiply, divide, modulus"
46
+ except Exception as e:
47
+ return f"Error: {str(e)}"
48
 
49
+ @tool
50
+ def wikipedia_search_tool(query: str) -> str:
51
+ """Search Wikipedia for information on any topic
 
52
 
53
+ Args:
54
+ query: The search query for Wikipedia
55
+
56
+ Returns:
57
+ Formatted Wikipedia search results
58
+ """
59
+ try:
60
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
61
+ formatted_results = "\n\n---\n\n".join([
62
+ f'Source: {doc.metadata["source"]}\nPage: {doc.metadata.get("page", "")}\n\nContent:\n{doc.page_content[:2000]}...'
63
+ for doc in search_docs
64
+ ])
65
+ return formatted_results
66
+ except Exception as e:
67
+ return f"Error searching Wikipedia: {str(e)}"
68
 
69
+ @tool
70
+ def web_search_tool(query: str) -> str:
71
+ """Search the web for current information using Tavily
 
72
 
73
+ Args:
74
+ query: The search query for web search
75
+
76
+ Returns:
77
+ Formatted web search results
78
+ """
79
+ try:
80
+ search_results = TavilySearchResults(max_results=3).invoke(query)
81
+ formatted_results = "\n\n---\n\n".join([
82
+ f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
83
+ for result in search_results
84
+ ])
85
+ return formatted_results
86
+ except Exception as e:
87
+ return f"Error searching web: {str(e)}"
88
 
89
+ @tool
90
+ def arxiv_search_tool(query: str) -> str:
91
+ """Search ArXiv for academic papers and research
 
92
 
93
+ Args:
94
+ query: The search query for ArXiv
95
+
96
+ Returns:
97
+ Formatted ArXiv search results
98
+ """
99
+ try:
100
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
101
+ formatted_results = "\n\n---\n\n".join([
102
+ f'Source: {doc.metadata["source"]}\nTitle: {doc.metadata.get("Title", "")}\n\nContent:\n{doc.page_content[:1500]}...'
103
+ for doc in search_docs
104
+ ])
105
+ return formatted_results
106
+ except Exception as e:
107
+ return f"Error searching ArXiv: {str(e)}"
108
 
109
+ class LangChainAgent:
110
+ """Multi-purpose LangChain agent with various capabilities."""
111
 
112
  def __init__(self, provider: str = "google"):
113
+ """Initialize the LangChain agent with specified LLM provider."""
114
  self.provider = provider
115
  self.llm = self._get_llm(provider)
116
  self.tools = self._initialize_tools()
117
+ self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
118
+ self.agent = self._create_agent()
119
 
120
  def _get_llm(self, provider: str):
121
  """Get the specified LLM."""
122
  if provider == "google":
123
+ return ChatGoogleGenerativeAI(
124
+ model="gemini-2.0-flash",
125
+ temperature=0,
126
+ max_tokens=2048
127
+ )
128
  elif provider == "huggingface":
129
  return ChatHuggingFace(
130
  llm=HuggingFaceEndpoint(
131
  repo_id="microsoft/DialoGPT-medium",
132
  temperature=0,
133
+ max_length=2048,
134
  ),
135
  )
136
  else:
137
  raise ValueError("Invalid provider. Choose 'google' or 'huggingface'.")
138
 
139
+ def _initialize_tools(self) -> List[Tool]:
140
  """Initialize all available tools."""
141
  return [
142
+ calculator_tool,
143
+ wikipedia_search_tool,
144
+ web_search_tool,
145
+ arxiv_search_tool,
146
  ]
147
 
148
+ def _create_agent(self):
149
+ """Create the LangChain agent with tools."""
150
+ return initialize_agent(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  tools=self.tools,
152
  llm=self.llm,
153
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
154
+ memory=self.memory,
155
  verbose=True,
156
+ handle_parsing_errors=True,
157
+ max_iterations=5,
158
+ early_stopping_method="generate"
159
  )
 
 
 
 
 
 
160
 
161
+ def _determine_approach(self, question: str) -> str:
162
+ """Determine the best approach for answering the question."""
163
  question_lower = question.lower()
164
 
165
  # Check for mathematical operations
166
+ math_keywords = ['calculate', 'compute', 'add', 'subtract', 'multiply', 'divide', 'math', 'equation', '+', '-', '*', '/', '%']
167
  if any(keyword in question_lower for keyword in math_keywords):
168
  return 'calculation'
169
 
170
  # Check for research-related queries
171
+ research_keywords = ['search', 'find', 'research', 'information', 'what is', 'who is', 'when', 'where', 'how', 'why']
172
  if any(keyword in question_lower for keyword in research_keywords):
173
  return 'research'
174
 
175
+ # Check for academic/scientific queries
176
+ academic_keywords = ['paper', 'study', 'research', 'academic', 'scientific', 'arxiv', 'journal']
177
+ if any(keyword in question_lower for keyword in academic_keywords):
178
+ return 'academic'
179
+
180
  return 'general'
181
 
182
  def __call__(self, question: str) -> str:
 
184
  try:
185
  print(f"Processing question: {question[:100]}...")
186
 
187
+ # Determine the best approach for this question
188
+ approach = self._determine_approach(question)
189
+ print(f"Selected approach: {approach}")
 
 
190
 
191
+ # Create a comprehensive prompt based on the approach
192
+ if approach == 'calculation':
193
+ enhanced_question = f"""
194
+ Solve this mathematical problem step by step:
195
 
196
+ {question}
197
 
198
+ Use the calculator tool if needed for complex calculations. Show your work clearly.
199
+ """
200
+ elif approach == 'research':
201
+ enhanced_question = f"""
202
+ Research and provide comprehensive information about:
203
+
204
+ {question}
205
+
206
+ Use Wikipedia search and web search tools to gather current and accurate information.
207
+ Cite your sources and provide detailed explanations.
208
+ """
209
+ elif approach == 'academic':
210
+ enhanced_question = f"""
211
+ Find academic and scientific information about:
212
+
213
+ {question}
214
+
215
+ Use ArXiv search and other research tools to find relevant academic papers and studies.
216
+ Provide citations and summarize key findings.
217
+ """
 
 
 
 
 
 
 
218
  else:
219
+ enhanced_question = f"""
220
+ Provide a comprehensive answer to:
221
+
222
+ {question}
223
+
224
+ Use appropriate tools as needed (calculator, search tools) to provide accurate information.
225
+ """
226
 
227
+ # Use the agent to process the question
228
+ result = self.agent.run(enhanced_question)
229
+
230
+ print(f"Generated answer: {str(result)[:200]}...")
231
+ return str(result)
232
 
233
  except Exception as e:
234
  error_msg = f"Error processing question: {str(e)}"
235
  print(error_msg)
236
+ # Provide a fallback response
237
+ try:
238
+ # Try a simple LLM response without tools
239
+ fallback_result = self.llm.invoke([HumanMessage(content=question)])
240
+ return fallback_result.content
241
+ except Exception as fallback_error:
242
+ return f"Error: Unable to process question. {str(e)}"
243
+
244
+ def reset_memory(self):
245
+ """Reset the conversation memory."""
246
+ self.memory.clear()
247
 
248
  # Test function
249
+ def test_langchain_agent():
250
+ """Test the LangChain agent with sample questions."""
251
+ agent = LangChainAgent(provider="google")
252
 
253
  test_questions = [
254
  "What is 25 * 34?",
 
262
  answer = agent(question)
263
  print(f"Answer: {answer}")
264
  print("-" * 50)
265
+ agent.reset_memory() # Reset memory between questions for testing
266
 
267
  if __name__ == "__main__":
268
+ test_langchain_agent()