Mohammad Haghir commited on
Commit
b06e6a2
·
1 Parent(s): 9e00b70

tool update

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -24,9 +24,9 @@ from agent_utils import wiki_ret, arxiv_ret, tavily_ret, handle_file_tool
24
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
25
 
26
  groq_api_key = os.getenv("GROQ_API_KEY")
27
-
28
  llm = ChatGroq(api_key=groq_api_key, model="gemma2-9b-it")
29
-
30
  class GraphState(TypedDict):
31
  messages: str #Annotated[list, operator.add]
32
  context: str
@@ -70,7 +70,7 @@ class BasicAgent:
70
  # Just make up a task_id.
71
  # Call the LLM
72
  messages = [HumanMessage(content=prompt)]
73
- response = (llm.invoke(messages)).content
74
  # cleaned_text = re.sub(r"<think>.*?</think>", "", response.content, flags=re.DOTALL)
75
 
76
  # json_start = response.find('{')
@@ -82,7 +82,7 @@ class BasicAgent:
82
  def create_graph(self):
83
  builder = StateGraph(GraphState)
84
  builder.add_node("agent", self.agent)
85
- builder.add_node("tools", ToolNode(tools = [wiki_ret, arxiv_ret, tavily_ret, handle_file_tool]))
86
 
87
  builder.add_edge(START, "agent")
88
  builder.add_conditional_edges("agent", tools_condition)
 
24
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
25
 
26
  groq_api_key = os.getenv("GROQ_API_KEY")
27
+ tools = [wiki_ret, arxiv_ret, tavily_ret, handle_file_tool]
28
  llm = ChatGroq(api_key=groq_api_key, model="gemma2-9b-it")
29
+ llm_with_tools = llm.bind_tools(tools)
30
  class GraphState(TypedDict):
31
  messages: str #Annotated[list, operator.add]
32
  context: str
 
70
  # Just make up a task_id.
71
  # Call the LLM
72
  messages = [HumanMessage(content=prompt)]
73
+ response = (llm_with_tools.invoke(messages)).content
74
  # cleaned_text = re.sub(r"<think>.*?</think>", "", response.content, flags=re.DOTALL)
75
 
76
  # json_start = response.find('{')
 
82
  def create_graph(self):
83
  builder = StateGraph(GraphState)
84
  builder.add_node("agent", self.agent)
85
+ builder.add_node("tools", ToolNode(tools = tools))
86
 
87
  builder.add_edge(START, "agent")
88
  builder.add_conditional_edges("agent", tools_condition)