sqfoo commited on
Commit
f8736ae
·
verified ·
1 Parent(s): 10f2f8e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +44 -16
agent.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  from dotenv import load_dotenv
3
  from typing import TypedDict, List, Dict, Any, Optional
 
4
  from langchain.agents import create_tool_calling_agent, AgentExecutor, initialize_agent
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
  from langchain_groq import ChatGroq
7
  from langchain_core.tools import tool
8
- from langchain_core.messages import HumanMessage
9
  from langchain_core.prompts import ChatPromptTemplate
10
 
11
  # 1. Web Browsing
@@ -231,24 +232,24 @@ def divide(a: float, b: float) -> float:
231
  # ("human", f"Question: {question}\nReport to validate: {final_answer}")
232
  class BasicAgent:
233
  def __init__(self):
234
- # self.model = ChatGoogleGenerativeAI(
235
- # model="gemini-2.0-flash-lite",
236
- # temperature=0,
237
- # max_tokens=128,
238
- # timeout=None,
239
- # max_retries=2,
240
- # google_api_key=os.getenv("GEMINI_API_KEY"),
241
- # # other params...
242
- # )
243
- self.model = ChatGroq(
244
- model="qwen-qwq-32b",
245
  temperature=0,
246
  max_tokens=128,
247
  timeout=None,
248
  max_retries=2,
249
- groq_api_key=os.getenv("GROQ_API_KEY")
250
  # other params...
251
  )
 
 
 
 
 
 
 
 
 
252
  # System Prompt for few shot prompting
253
  self.sys_prompt = """"
254
  You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
@@ -275,6 +276,9 @@ class BasicAgent:
275
  If Task ID is included in the question, remember to call the relevant read tools [ie. read_file, excel_read, csv_read, mp3_listen, image_caption]
276
  """
277
  self.tools = [duckduck_websearch, serper_websearch, visit_webpage, wiki_search, text_splitter, youtube_transcript, read_file, excel_read, csv_read, mp3_listen, image_caption, python_tool]
 
 
 
278
  self.prompt = ChatPromptTemplate.from_messages([
279
  ("system", self.sys_prompt),
280
  ("human", "{input}")
@@ -287,6 +291,7 @@ class BasicAgent:
287
  system_prompt=self.prompt,
288
  handle_parsing_errors="Check your output and make sure it conforms, use the Action/Action Input syntax"
289
  )
 
290
  print("BasicAgent initialized.")
291
 
292
  def __call__(self, task: dict) -> str:
@@ -296,10 +301,33 @@ class BasicAgent:
296
  # fixed_answer = response['message'][-1].content
297
 
298
  if file_name == "" or file_name is None:
299
- fixed_answer = self.agent.run(question)
300
  else:
301
- fixed_answer = self.agent.run(f'{question} with TASK-ID: {task_id}')
 
302
  # fixed_answer = "This is a default answer."
 
 
 
 
303
  print(f"Agent returning fixed answer: {fixed_answer}")
304
  time.sleep(60)
305
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from typing import TypedDict, List, Dict, Any, Optional
4
+ from langgraph.graph import StateGraph, START, END, MessagesState
5
  from langchain.agents import create_tool_calling_agent, AgentExecutor, initialize_agent
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_core.tools import tool
9
+ from langchain_core.messages import HumanMessage, SystemMessage
10
  from langchain_core.prompts import ChatPromptTemplate
11
 
12
  # 1. Web Browsing
 
232
  # ("human", f"Question: {question}\nReport to validate: {final_answer}")
233
  class BasicAgent:
234
  def __init__(self):
235
+ self.model = ChatGoogleGenerativeAI(
236
+ model="gemini-2.0-flash-lite",
 
 
 
 
 
 
 
 
 
237
  temperature=0,
238
  max_tokens=128,
239
  timeout=None,
240
  max_retries=2,
241
+ google_api_key=os.getenv("GEMINI_API_KEY"),
242
  # other params...
243
  )
244
+ # self.model = ChatGroq(
245
+ # model="qwen-qwq-32b",
246
+ # temperature=0,
247
+ # max_tokens=128,
248
+ # timeout=None,
249
+ # max_retries=2,
250
+ # groq_api_key=os.getenv("GROQ_API_KEY")
251
+ # # other params...
252
+ # )
253
  # System Prompt for few shot prompting
254
  self.sys_prompt = """"
255
  You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
 
276
  If Task ID is included in the question, remember to call the relevant read tools [ie. read_file, excel_read, csv_read, mp3_listen, image_caption]
277
  """
278
  self.tools = [duckduck_websearch, serper_websearch, visit_webpage, wiki_search, text_splitter, youtube_transcript, read_file, excel_read, csv_read, mp3_listen, image_caption, python_tool]
279
+ self.binded_model = self.model.bind_tools(self.tools)
280
+ self.sys_msg = SystemMessage(content=self.sys_prompt)
281
+
282
  self.prompt = ChatPromptTemplate.from_messages([
283
  ("system", self.sys_prompt),
284
  ("human", "{input}")
 
291
  system_prompt=self.prompt,
292
  handle_parsing_errors="Check your output and make sure it conforms, use the Action/Action Input syntax"
293
  )
294
+ self.app = self.__graph_compile__()
295
  print("BasicAgent initialized.")
296
 
297
  def __call__(self, task: dict) -> str:
 
301
  # fixed_answer = response['message'][-1].content
302
 
303
  if file_name == "" or file_name is None:
304
+ question = question
305
  else:
306
+ question = f"{question} with TASK-ID: {task_id}"
307
+ # fixed_answer = self.agent.run(f'{question} with TASK-ID: {task_id}')
308
  # fixed_answer = "This is a default answer."
309
+ # fixed_answer = self.agent.run(question)
310
+ human_message = [HumanMessage(content=question)]
311
+ messages = self.app.invoke({"message": human_message})
312
+ fixed_answer = messages['messages'][-1].content
313
  print(f"Agent returning fixed answer: {fixed_answer}")
314
  time.sleep(60)
315
+ return fixed_answer
316
+
317
+ def __graph_compile__(self):
318
+ builder = StateGraph(MessagesState)
319
+ builder.add_node("assistant", self.assistant)
320
+ builder.add_node("tools", ToolNode(self.tools))
321
+ builder.add_edge(START, "assistant")
322
+ builder.add_conditional_edges(
323
+ "assistant",
324
+ # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
325
+ # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
326
+ tools_condition,
327
+ )
328
+ builder.add_edge("tools", "assistant")
329
+ return builder.compile() # Compile graph
330
+
331
+ def assistant(self, state: MessagesState):
332
+ """Assistant Node"""
333
+ return {"message": [self.binded_model.invoke([self.sys_msg] + state["message"])]}