Shago commited on
Commit
394cce9
·
verified ·
1 Parent(s): 60a9916

Update agents/agents_nodes.py

Browse files
Files changed (1) hide show
  1. agents/agents_nodes.py +4 -50
agents/agents_nodes.py CHANGED
@@ -1,17 +1,3 @@
1
- import json
2
- from transformers import pipeline
3
- import torch
4
- from langchain_core.messages import AIMessage, ToolMessage
5
- from langchain_huggingface import HuggingFacePipeline, ChatHuggingFace,HuggingFaceEndpoint
6
- from langchain_core.runnables import RunnableLambda
7
- from langgraph.prebuilt import ToolNode
8
- from utils.state_utils import AgentState
9
- from tools.financial_tools import time_value_tool
10
-
11
-
12
-
13
- # LLM instantation
14
-
15
  text_generator = pipeline(
16
  "text-generation", # Task type
17
  #model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
@@ -25,41 +11,13 @@ text_generator = pipeline(
25
 
26
  llm = HuggingFacePipeline(pipeline=text_generator)
27
 
28
- # llm_instantiated = llm.bind(
29
- # functions=[time_value_tool],
30
- # function_call={"name": "time_value_tool"},
31
- # )
32
-
33
- # llm_instantiated = llm.bind(
34
- # tools=[{"type": "function", "function": time_value_tool.get_schema()}],
35
- # tool_choice="auto"
36
- # )
37
-
38
- llm_instantiated = llm.bind(
39
- tools=[{
40
- "type": "function",
41
- "function": {
42
- "name": time_value_tool.name,
43
- "description": time_value_tool.description,
44
- "parameters": time_value_tool.args
45
- }
46
- }],
47
- tool_choice={
48
- "type": "function",
49
- "function": {"name": time_value_tool.name}
50
- }
51
  )
52
-
53
-
54
  def agent_node(state: AgentState):
55
  response = llm_instantiated.invoke(state["messages"])
56
- if response.tool_calls:
57
- tool_name = response.tool_calls["name"]
58
- if tool_name != "time_value_tool":
59
- return {
60
- "output": {"error": f"Invalid tool call: {tool_name}"},
61
- "messages": [response]
62
- }
63
  if not (hasattr(response, 'tool_calls') and response.tool_calls):
64
  error_message = AIMessage(content="Error: Model failed to generate tool call.")
65
  return {"messages": [error_message]}
@@ -75,12 +33,8 @@ F_MAPPING = {
75
  "A/P": "Annual", "A/F": "Annual", "A/G": "Annual"
76
  }
77
 
78
-
79
  def format_output(state: AgentState):
80
  try:
81
- if state.get("output"):
82
- return state
83
-
84
  # The last message should be the ToolMessage (from the tool node)
85
  if not state["messages"] or not isinstance(state["messages"][-1], ToolMessage):
86
  return {"output": {"error": "No tool result found in the last message"}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  text_generator = pipeline(
2
  "text-generation", # Task type
3
  #model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
 
11
 
12
  llm = HuggingFacePipeline(pipeline=text_generator)
13
 
14
+ llm = HuggingFacePipeline(pipeline=text_generator)
15
+ llm_instantiated = llm.bind_tools(
16
+ [time_value_tool],
17
+ tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  )
 
 
19
  def agent_node(state: AgentState):
20
  response = llm_instantiated.invoke(state["messages"])
 
 
 
 
 
 
 
21
  if not (hasattr(response, 'tool_calls') and response.tool_calls):
22
  error_message = AIMessage(content="Error: Model failed to generate tool call.")
23
  return {"messages": [error_message]}
 
33
  "A/P": "Annual", "A/F": "Annual", "A/G": "Annual"
34
  }
35
 
 
36
  def format_output(state: AgentState):
37
  try:
 
 
 
38
  # The last message should be the ToolMessage (from the tool node)
39
  if not state["messages"] or not isinstance(state["messages"][-1], ToolMessage):
40
  return {"output": {"error": "No tool result found in the last message"}}