Shago commited on
Commit
2c430ef
·
verified ·
1 Parent(s): 2ab2c6d

Update agents/agents_nodes.py

Browse files
Files changed (1) hide show
  1. agents/agents_nodes.py +21 -22
agents/agents_nodes.py CHANGED
@@ -12,32 +12,31 @@ from tools.financial_tools import time_value_tool
12
 
13
  # LLM instantation
14
 
15
- # text_generator = pipeline(
16
- # "text-generation", # Task type
17
- # # model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
18
- # # model="google/gemma-3n-E2B-it",
19
- # # model="google/gemma-3n-e4b-it",
20
- # model="Qwen/Qwen3-Embedding-0.6B",
21
- # # device="cuda" if torch.cuda.is_available() else "cpu",
22
- # device= "cpu",
23
- # torch_dtype=torch.bfloat16,
24
- # max_new_tokens=700 # Limit output length
25
- # )
26
-
27
- llm_endpoint = HuggingFaceEndpoint(
28
- endpoint_url="<your_endpoint_url>",
29
- task="text-generation",
30
- max_new_tokens=1024,
31
- do_sample=False
32
  )
33
- # llm = HuggingFacePipeline(pipeline=text_generator)
34
- llm = ChatHuggingFace(llm=llm_endpoint)
35
 
36
- llm_instantiated = llm.bind_tools(
37
- [time_value_tool],
38
- tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
 
 
 
39
  )
40
 
 
 
 
 
 
41
  def agent_node(state: AgentState):
42
  response = llm_instantiated.invoke(state["messages"])
43
  if not (hasattr(response, 'tool_calls') and response.tool_calls):
 
12
 
13
  # LLM instantation
14
 
15
+ text_generator = pipeline(
16
+ "text-generation", # Task type
17
+ # model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
18
+ # model="google/gemma-3n-E2B-it",
19
+ # model="google/gemma-3n-e4b-it",
20
+ model="Qwen/Qwen3-Embedding-0.6B",
21
+ # device="cuda" if torch.cuda.is_available() else "cpu",
22
+ device= "cpu",
23
+ torch_dtype=torch.bfloat16,
24
+ max_new_tokens=700 # Limit output length
 
 
 
 
 
 
 
25
  )
 
 
26
 
27
+ llm = HuggingFacePipeline(pipeline=text_generator)
28
+
29
+ llm_instantiated = llm.bind(
30
+ functions=[tool_schema],
31
+ function_call={"name": "time_value_tool"},
32
+ stop_sequences=["<|im_end|>"]
33
  )
34
 
35
+ # llm_instantiated = llm.bind_tools(
36
+ # [time_value_tool],
37
+ # tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
38
+ # )
39
+
40
  def agent_node(state: AgentState):
41
  response = llm_instantiated.invoke(state["messages"])
42
  if not (hasattr(response, 'tool_calls') and response.tool_calls):