Aditya0619 commited on
Commit
7a7bd49
·
verified ·
1 Parent(s): a8e6839

Update crewai_agent.py

Browse files
Files changed (1) hide show
  1. crewai_agent.py +42 -18
crewai_agent.py CHANGED
@@ -2,20 +2,23 @@ import os
2
  from typing import List, Dict, Any, Optional
3
  from dotenv import load_dotenv
4
 
 
 
 
5
  from langchain.agents import AgentType, initialize_agent, Tool
6
  from langchain.memory import ConversationBufferMemory
7
- from langchain.schema import BaseMessage, HumanMessage, AIMessage
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
- from langchain.tools import tool
13
  from langchain.prompts import PromptTemplate
14
  from langchain.chains import LLMChain
15
 
16
  # Load environment variables
17
  GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
18
- HUGGINGFACE_API_TOKEN= os.getenv('HUGGINGFACE_API_TOKEN')
19
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
20
 
21
 
@@ -80,7 +83,10 @@ def web_search_tool(query: str) -> str:
80
  Formatted web search results
81
  """
82
  try:
83
- search_results = TavilySearchResults(max_results=3,api_key=TAVILY_API_KEY).invoke(query)
 
 
 
84
  formatted_results = "\n\n---\n\n".join([
85
  f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
86
  for result in search_results
@@ -123,20 +129,23 @@ class LangChainAgent:
123
  def _get_llm(self, provider: str):
124
  """Get the specified LLM."""
125
  if provider == "google":
 
 
126
  return ChatGoogleGenerativeAI(
127
- model="gemini-2.0-flash",
128
  temperature=0,
129
  max_tokens=2048,
130
- api_key=GOOGLE_API_KEY
131
-
132
  )
133
  elif provider == "huggingface":
 
 
134
  return ChatHuggingFace(
135
  llm=HuggingFaceEndpoint(
136
  repo_id="microsoft/DialoGPT-medium",
137
  temperature=0,
138
  max_length=2048,
139
- api_key=HUGGINGFACE_API_TOKEN
140
  ),
141
  )
142
  else:
@@ -153,16 +162,21 @@ class LangChainAgent:
153
 
154
  def _create_agent(self):
155
  """Create the LangChain agent with tools."""
156
- return initialize_agent(
157
- tools=self.tools,
158
- llm=self.llm,
159
- agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
160
- memory=self.memory,
161
- verbose=True,
162
- handle_parsing_errors=True,
163
- max_iterations=5,
164
- early_stopping_method="generate"
165
- )
 
 
 
 
 
166
 
167
  def _determine_approach(self, question: str) -> str:
168
  """Determine the best approach for answering the question."""
@@ -190,6 +204,15 @@ class LangChainAgent:
190
  try:
191
  print(f"Processing question: {question[:100]}...")
192
 
 
 
 
 
 
 
 
 
 
193
  # Determine the best approach for this question
194
  approach = self._determine_approach(question)
195
  print(f"Selected approach: {approach}")
@@ -272,3 +295,4 @@ def test_langchain_agent():
272
 
273
  if __name__ == "__main__":
274
  test_langchain_agent()
 
 
2
  from typing import List, Dict, Any, Optional
3
  from dotenv import load_dotenv
4
 
5
+ # Load environment variables from .env file
6
+ load_dotenv()
7
+
8
  from langchain.agents import AgentType, initialize_agent, Tool
9
  from langchain.memory import ConversationBufferMemory
10
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
11
  from langchain_google_genai import ChatGoogleGenerativeAI
12
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
13
  from langchain_community.tools.tavily_search import TavilySearchResults
14
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
15
+ from langchain_core.tools import tool
16
  from langchain.prompts import PromptTemplate
17
  from langchain.chains import LLMChain
18
 
19
  # Load environment variables
20
  GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
21
+ HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
22
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
23
 
24
 
 
83
  Formatted web search results
84
  """
85
  try:
86
+ if not TAVILY_API_KEY:
87
+ return "Error: TAVILY_API_KEY not found in environment variables"
88
+
89
+ search_results = TavilySearchResults(max_results=3, api_key=TAVILY_API_KEY).invoke(query)
90
  formatted_results = "\n\n---\n\n".join([
91
  f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
92
  for result in search_results
 
129
  def _get_llm(self, provider: str):
130
  """Get the specified LLM."""
131
  if provider == "google":
132
+ if not GOOGLE_API_KEY:
133
+ raise ValueError("GOOGLE_API_KEY not found in environment variables")
134
  return ChatGoogleGenerativeAI(
135
+ model="gemini-1.5-flash",
136
  temperature=0,
137
  max_tokens=2048,
138
+ google_api_key=GOOGLE_API_KEY
 
139
  )
140
  elif provider == "huggingface":
141
+ if not HUGGINGFACE_API_TOKEN:
142
+ raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
143
  return ChatHuggingFace(
144
  llm=HuggingFaceEndpoint(
145
  repo_id="microsoft/DialoGPT-medium",
146
  temperature=0,
147
  max_length=2048,
148
+ huggingfacehub_api_token=HUGGINGFACE_API_TOKEN
149
  ),
150
  )
151
  else:
 
162
 
163
  def _create_agent(self):
164
  """Create the LangChain agent with tools."""
165
+ try:
166
+ return initialize_agent(
167
+ tools=self.tools,
168
+ llm=self.llm,
169
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
170
+ memory=self.memory,
171
+ verbose=True,
172
+ handle_parsing_errors=True,
173
+ max_iterations=3,
174
+ early_stopping_method="generate"
175
+ )
176
+ except Exception as e:
177
+ print(f"Error creating agent: {e}")
178
+ # Return a simple agent without tools as fallback
179
+ return None
180
 
181
  def _determine_approach(self, question: str) -> str:
182
  """Determine the best approach for answering the question."""
 
204
  try:
205
  print(f"Processing question: {question[:100]}...")
206
 
207
+ # If agent initialization failed, use direct LLM
208
+ if self.agent is None:
209
+ print("Agent not available, using direct LLM response")
210
+ try:
211
+ response = self.llm.invoke([HumanMessage(content=question)])
212
+ return response.content
213
+ except Exception as llm_error:
214
+ return f"Error: Unable to process question. {str(llm_error)}"
215
+
216
  # Determine the best approach for this question
217
  approach = self._determine_approach(question)
218
  print(f"Selected approach: {approach}")
 
295
 
296
  if __name__ == "__main__":
297
  test_langchain_agent()
298
+