Shago commited on
Commit
f86c996
·
verified ·
1 Parent(s): f183d0d

Upload agents_nodes.py

Browse files
Files changed (1) hide show
  1. agents/agents_nodes.py +9 -59
agents/agents_nodes.py CHANGED
@@ -1,76 +1,26 @@
1
  import json
2
  from langchain_core.messages import AIMessage, ToolMessage
3
  from langgraph.prebuilt import ToolNode
4
- from langchain_huggingface import HuggingFacePipeline
5
  from utils.state_utils import AgentState
6
  from tools.financial_tools import time_value_tool
7
- from transformers import pipeline
8
- import torch
9
 
10
- time_value_schema = {
11
- "name": "time_value_tool",
12
- "description": "Computes time value of money factors using financial formulas",
13
- "parameters": {
14
- "type": "object",
15
- "properties": {
16
- "CF": {"type": "number"},
17
- "F": {"type": "string"},
18
- "i": {"type": "number"},
19
- "n": {"type": "number"},
20
- "g": {"type": "number", "nullable": True}
21
- },
22
- "required": ["CF", "F", "i", "n"]
23
- }
24
- }
25
-
26
-
27
-
28
- text_generator = pipeline(
29
- "text-generation", # Task type
30
- #model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
31
- # model="google/gemma-3n-E2B-it",
32
- model="google/gemma-3n-e4b-it",
33
- #model="Qwen/Qwen3-Embedding-0.6B",
34
- device= "cpu",
35
- torch_dtype=torch.bfloat16,
36
- max_new_tokens=200 # Limit output length
37
- )
38
 
39
- llm = HuggingFacePipeline(pipeline=text_generator)
40
 
41
- llm = HuggingFacePipeline(pipeline=text_generator)
42
- # llm_instantiated = llm.bind_tools(
43
- # [time_value_tool],
44
- # tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
45
- # )
46
-
47
- llm_instantiated = llm.bind(
48
- tools=[time_value_schema],
49
- tool_choice={
50
- "type": "function",
51
- "function": {"name": "time_value_tool"}
52
- }
53
  )
54
 
55
-
56
- # def agent_node(state: AgentState):
57
- # response = llm_instantiated.invoke(state["messages"])
58
- # if not (hasattr(response, 'tool_calls') and response.tool_calls):
59
- # error_message = AIMessage(content="Error: Model failed to generate tool call.")
60
- # return {"messages": [error_message]}
61
- # return {"messages": [response]}
62
-
63
  def agent_node(state: AgentState):
64
  response = llm_instantiated.invoke(state["messages"])
65
- if not getattr(response, 'tool_calls', None):
66
- return {
67
- "messages": [response],
68
- "output": {"error": "Failed to generate tool call"}
69
- }
70
  return {"messages": [response]}
71
 
72
-
73
-
74
  # Tool node executes the tool
75
  tool_node = ToolNode([time_value_tool])
76
 
 
1
  import json
2
  from langchain_core.messages import AIMessage, ToolMessage
3
  from langgraph.prebuilt import ToolNode
 
4
  from utils.state_utils import AgentState
5
  from tools.financial_tools import time_value_tool
6
+ from langchain_ollama import ChatOllama
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# LLM instantiation: a local Ollama chat model, deterministic (temperature=0).
llm = ChatOllama(model="qwen3:4b", temperature=0)

# Bind the financial tool and force the model to call it on every turn,
# so agent_node can rely on tool_calls being present in a good response.
llm_instantiated = llm.bind_tools(
    [time_value_tool],
    tool_choice={"type": "function", "function": {"name": "time_value_tool"}},
)
16
 
 
 
 
 
 
 
 
 
17
def agent_node(state: AgentState):
    """Run the tool-bound LLM over the conversation and append its reply.

    Returns a state delta of the form {"messages": [<message>]}. When the
    model's response carries no tool calls, an explanatory AIMessage is
    appended instead so downstream nodes can detect the failure.
    """
    response = llm_instantiated.invoke(state["messages"])
    # Guard clause: a usable response must expose a non-empty tool_calls list.
    if getattr(response, "tool_calls", None):
        return {"messages": [response]}
    fallback = AIMessage(content="Error: Model failed to generate tool call.")
    return {"messages": [fallback]}
23
 
 
 
24
# Graph node that executes the time_value_tool calls requested by agent_node.
tool_node = ToolNode([time_value_tool])
26