# agents/agents_nodes.py — LangGraph agent/tool nodes for the financial workflow.
# Reconstructed from the "Update agents/agents_nodes.py" commit (+37 -7):
# the agent LLM backend was switched from a local ChatOllama model to the
# Hugging Face Inference API (ChatHuggingFace over HuggingFaceEndpoint).
"""Agent and tool nodes for the LangGraph financial workflow.

``agent_node`` invokes a Hugging Face chat model that is force-bound to
``time_value_tool``; ``tool_node`` executes the tool call the model emits.
"""

import os

from langchain_community.chat_models import ChatHuggingFace
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode

from tools.financial_tools import time_value_tool
from utils.state_utils import AgentState

# Hugging Face Inference API backend. Swap HF_MODEL to change models.
HF_MODEL = "google/gemma-2b-it"
llm_endpoint = HuggingFaceEndpoint(
    endpoint_url=f"https://api-inference.huggingface.co/models/{HF_MODEL}",
    # NOTE(review): raises KeyError at import time when HF_TOKEN is unset —
    # fails fast, but consider a clearer startup error message.
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    max_new_tokens=500,
    temperature=0,
)

# Wrap the raw text-completion endpoint in a chat-model interface.
llm = ChatHuggingFace(llm=llm_endpoint)

# Bind the tool and force the model to call it on every turn.
llm_instantiated = llm.bind_tools(
    [time_value_tool],
    tool_choice={"type": "function", "function": {"name": "time_value_tool"}},
)


def agent_node(state: AgentState):
    """Run the tool-bound LLM on the conversation history.

    Args:
        state: Graph state whose ``"messages"`` entry holds the chat history.

    Returns:
        ``{"messages": [...]}`` with a single new message: the model response
        carrying a tool call, or an ``AIMessage`` error when the model
        answered in plain text instead of emitting a tool call.
    """
    response = llm_instantiated.invoke(state["messages"])
    # Some models ignore forced tool_choice and reply with plain text;
    # surface that as an explicit error message rather than crashing later.
    if not hasattr(response, "tool_calls") or not response.tool_calls:
        return {"messages": [AIMessage(content="Error: No tool call generated")]}
    return {"messages": [response]}


# Tool node executes the tool call(s) produced by the agent node.
tool_node = ToolNode([time_value_tool])