Update agents/agents_nodes.py
Browse files- agents/agents_nodes.py +4 -50
agents/agents_nodes.py
CHANGED
|
@@ -1,17 +1,3 @@
|
|
| 1 |
-
import json
|
| 2 |
-
from transformers import pipeline
|
| 3 |
-
import torch
|
| 4 |
-
from langchain_core.messages import AIMessage, ToolMessage
|
| 5 |
-
from langchain_huggingface import HuggingFacePipeline, ChatHuggingFace,HuggingFaceEndpoint
|
| 6 |
-
from langchain_core.runnables import RunnableLambda
|
| 7 |
-
from langgraph.prebuilt import ToolNode
|
| 8 |
-
from utils.state_utils import AgentState
|
| 9 |
-
from tools.financial_tools import time_value_tool
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
# LLM instantation
|
| 14 |
-
|
| 15 |
text_generator = pipeline(
|
| 16 |
"text-generation", # Task type
|
| 17 |
#model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
|
|
@@ -25,41 +11,13 @@ text_generator = pipeline(
|
|
| 25 |
|
| 26 |
llm = HuggingFacePipeline(pipeline=text_generator)
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
# llm_instantiated = llm.bind(
|
| 34 |
-
# tools=[{"type": "function", "function": time_value_tool.get_schema()}],
|
| 35 |
-
# tool_choice="auto"
|
| 36 |
-
# )
|
| 37 |
-
|
| 38 |
-
llm_instantiated = llm.bind(
|
| 39 |
-
tools=[{
|
| 40 |
-
"type": "function",
|
| 41 |
-
"function": {
|
| 42 |
-
"name": time_value_tool.name,
|
| 43 |
-
"description": time_value_tool.description,
|
| 44 |
-
"parameters": time_value_tool.args
|
| 45 |
-
}
|
| 46 |
-
}],
|
| 47 |
-
tool_choice={
|
| 48 |
-
"type": "function",
|
| 49 |
-
"function": {"name": time_value_tool.name}
|
| 50 |
-
}
|
| 51 |
)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
def agent_node(state: AgentState):
|
| 55 |
response = llm_instantiated.invoke(state["messages"])
|
| 56 |
-
if response.tool_calls:
|
| 57 |
-
tool_name = response.tool_calls["name"]
|
| 58 |
-
if tool_name != "time_value_tool":
|
| 59 |
-
return {
|
| 60 |
-
"output": {"error": f"Invalid tool call: {tool_name}"},
|
| 61 |
-
"messages": [response]
|
| 62 |
-
}
|
| 63 |
if not (hasattr(response, 'tool_calls') and response.tool_calls):
|
| 64 |
error_message = AIMessage(content="Error: Model failed to generate tool call.")
|
| 65 |
return {"messages": [error_message]}
|
|
@@ -75,12 +33,8 @@ F_MAPPING = {
|
|
| 75 |
"A/P": "Annual", "A/F": "Annual", "A/G": "Annual"
|
| 76 |
}
|
| 77 |
|
| 78 |
-
|
| 79 |
def format_output(state: AgentState):
|
| 80 |
try:
|
| 81 |
-
if state.get("output"):
|
| 82 |
-
return state
|
| 83 |
-
|
| 84 |
# The last message should be the ToolMessage (from the tool node)
|
| 85 |
if not state["messages"] or not isinstance(state["messages"][-1], ToolMessage):
|
| 86 |
return {"output": {"error": "No tool result found in the last message"}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
text_generator = pipeline(
|
| 2 |
"text-generation", # Task type
|
| 3 |
#model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
|
|
|
|
| 11 |
|
| 12 |
llm = HuggingFacePipeline(pipeline=text_generator)
|
| 13 |
|
| 14 |
+
llm = HuggingFacePipeline(pipeline=text_generator)
|
| 15 |
+
llm_instantiated = llm.bind_tools(
|
| 16 |
+
[time_value_tool],
|
| 17 |
+
tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
)
|
|
|
|
|
|
|
| 19 |
def agent_node(state: AgentState):
|
| 20 |
response = llm_instantiated.invoke(state["messages"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
if not (hasattr(response, 'tool_calls') and response.tool_calls):
|
| 22 |
error_message = AIMessage(content="Error: Model failed to generate tool call.")
|
| 23 |
return {"messages": [error_message]}
|
|
|
|
| 33 |
"A/P": "Annual", "A/F": "Annual", "A/G": "Annual"
|
| 34 |
}
|
| 35 |
|
|
|
|
| 36 |
def format_output(state: AgentState):
|
| 37 |
try:
|
|
|
|
|
|
|
|
|
|
| 38 |
# The last message should be the ToolMessage (from the tool node)
|
| 39 |
if not state["messages"] or not isinstance(state["messages"][-1], ToolMessage):
|
| 40 |
return {"output": {"error": "No tool result found in the last message"}}
|