Shago commited on
Commit
67e33c2
·
verified ·
1 Parent(s): f86c996

Update agents/agents_nodes.py

Browse files
Files changed (1) hide show
  1. agents/agents_nodes.py +37 -7
agents/agents_nodes.py CHANGED
@@ -3,24 +3,54 @@ from langchain_core.messages import AIMessage, ToolMessage
3
  from langgraph.prebuilt import ToolNode
4
  from utils.state_utils import AgentState
5
  from tools.financial_tools import time_value_tool
6
- from langchain_ollama import ChatOllama
7
 
8
 
9
- # LLL instantation
10
 
11
- llm = ChatOllama(model="qwen3:4b", temperature=0)
12
- llm_instantiated = llm.bind_tools(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  [time_value_tool],
14
  tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
15
  )
16
 
17
  def agent_node(state: AgentState):
18
  response = llm_instantiated.invoke(state["messages"])
19
- if not (hasattr(response, 'tool_calls') and response.tool_calls):
20
- error_message = AIMessage(content="Error: Model failed to generate tool call.")
21
- return {"messages": [error_message]}
22
  return {"messages": [response]}
23
 
 
24
  # Tool node executes the tool
25
  tool_node = ToolNode([time_value_tool])
26
 
 
3
  from langgraph.prebuilt import ToolNode
4
  from utils.state_utils import AgentState
5
  from tools.financial_tools import time_value_tool
 
6
 
7
 
 
8
 
9
+ # # LLL instantation
10
+
11
+ # llm = ChatOllama(model="qwen3:4b", temperature=0)
12
+ # llm_instantiated = llm.bind_tools(
13
+ # [time_value_tool],
14
+ # tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
15
+ # )
16
+
17
+ # def agent_node(state: AgentState):
18
+ # response = llm_instantiated.invoke(state["messages"])
19
+ # if not (hasattr(response, 'tool_calls') and response.tool_calls):
20
+ # error_message = AIMessage(content="Error: Model failed to generate tool call.")
21
+ # return {"messages": [error_message]}
22
+ # return {"messages": [response]}
23
+
24
+ import os
25
+ from langchain_community.chat_models import ChatHuggingFace
26
+ from langchain_community.llms import HuggingFaceEndpoint
27
+ from langchain_core.messages import AIMessage
28
+
29
+ # Initialize Hugging Face endpoint (replace with your model)
30
+ HF_MODEL = "google/gemma-2b-it"
31
+ llm_endpoint = HuggingFaceEndpoint(
32
+ endpoint_url=f"https://api-inference.huggingface.co/models/{HF_MODEL}",
33
+ huggingfacehub_api_token=os.environ["HF_TOKEN"],
34
+ max_new_tokens=500,
35
+ temperature=0
36
+ )
37
+
38
+ # Wrap in chat model
39
+ llm = ChatHuggingFace(llm=llm_endpoint)
40
+
41
+ # Bind tools to model
42
+ llm_instantiated = llm.bind_tools(
43
  [time_value_tool],
44
  tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
45
  )
46
 
47
  def agent_node(state: AgentState):
48
  response = llm_instantiated.invoke(state["messages"])
49
+ if not hasattr(response, "tool_calls") or not response.tool_calls:
50
+ return {"messages": [AIMessage(content="Error: No tool call generated")]}
 
51
  return {"messages": [response]}
52
 
53
+
54
  # Tool node executes the tool
55
  tool_node = ToolNode([time_value_tool])
56