# NOTE: removed non-Python extraction residue (file-size / commit-hash / line-number dump).
import json
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.prebuilt import ToolNode
from utils.state_utils import AgentState
from tools.financial_tools import time_value_tool
# # LLM instantiation (legacy local-Ollama setup, kept for reference)
# llm = ChatOllama(model="qwen3:4b", temperature=0)
# llm_instantiated = llm.bind_tools(
# [time_value_tool],
# tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
# )
# def agent_node(state: AgentState):
# response = llm_instantiated.invoke(state["messages"])
# if not (hasattr(response, 'tool_calls') and response.tool_calls):
# error_message = AIMessage(content="Error: Model failed to generate tool call.")
# return {"messages": [error_message]}
# return {"messages": [response]}
import os
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# Hosted model to use; previously-tried alternatives kept as inline comments.
repo_id= "google/gemma-3n-e4b-it" #"deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" #"google/gemma-2b-it"
# Remote inference endpoint. NOTE(review): os.environ["HF_TOKEN"] raises
# KeyError at import time if the token is unset — presumably intentional
# fail-fast; confirm. temperature=0.0 keeps generation deterministic.
llm_endpoint = HuggingFaceEndpoint(
repo_id=repo_id,
huggingfacehub_api_token=os.environ["HF_TOKEN"],
max_new_tokens=1024,
temperature=0.0
)
# Chat-style wrapper around the raw text-generation endpoint.
llm = ChatHuggingFace(llm=llm_endpoint)
# Bind the single financial tool and force the model to call it on every turn
# (tool_choice pins the function name, so free-text answers are disallowed).
llm_instantiated = llm.bind_tools(
[time_value_tool],
tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
)
def agent_node(state: AgentState):
    """Run the tool-bound LLM over the conversation so far.

    Appends the model's reply to the message list; if the model failed to
    emit a tool call, an error AIMessage is appended instead.
    """
    reply = llm_instantiated.invoke(state["messages"])
    # Missing or empty tool_calls both count as failure.
    if getattr(reply, "tool_calls", None):
        return {"messages": [reply]}
    return {"messages": [AIMessage(content="Error: No tool call generated")]}
# Graph node that executes the tool call(s) found on the latest AIMessage and
# appends the resulting ToolMessage(s) to state["messages"].
tool_node = ToolNode([time_value_tool])
# Maps an engineering-economics factor symbol ("target/given" notation) to the
# name of the value it produces: P -> present value, F -> future value,
# A -> equivalent annual amount.
F_MAPPING = {
    f"{target}/{given}": label
    for target, label in (("P", "PV"), ("F", "FV"), ("A", "Annual"))
    for given in "PFAG"
    if given != target
}
def format_output(state: AgentState):
    """Extract the tool result from the conversation and shape the final output.

    Expects state["messages"] to end with [..., AIMessage(with tool_calls),
    ToolMessage]. Parses the ToolMessage content as JSON, maps the requested
    factor (args["F"]) to the value it produces (PV / FV / Annual), and returns
    {"output": {<key>: <value rounded to 2 decimals>}}. Any structural problem
    yields {"output": {"error": ...}} instead of raising.
    """
    try:
        messages = state["messages"]
        # The tool node appends its result last; bail out if it isn't there.
        if not messages or not isinstance(messages[-1], ToolMessage):
            return {"output": {"error": "No tool result found in the last message"}}
        tool_message = messages[-1]
        # Parse the content of the tool message as JSON
        tool_result = json.loads(tool_message.content)
        # The AIMessage that requested the tool call sits just before the result.
        if len(messages) < 2 or not isinstance(messages[-2], AIMessage):
            return {"output": {"error": "No AI message (with tool call) found before the tool message"}}
        ai_message = messages[-2]
        if not ai_message.tool_calls:
            return {"output": {"error": "The AI message does not contain tool calls"}}
        # We take the first tool call (since we forced one tool).
        # BUG FIX: the original bound the whole tool_calls list here, then
        # indexed it with "args" — a TypeError that made every call fail.
        tool_call = ai_message.tool_calls[0]
        args = tool_call["args"]
        # Get the factor type from the args
        factor_type = args["F"]
        if factor_type not in F_MAPPING:
            return {"output": {"error": f"Unrecognized factor type: {factor_type}"}}
        result_key = F_MAPPING[factor_type]
        if result_key not in tool_result:
            return {"output": {"error": f"Expected key {result_key} not found in tool result"}}
        value = tool_result[result_key]
        return {"output": {result_key: round(float(value), 2)}}
    # ValueError covers both json.JSONDecodeError (its subclass) and a
    # non-numeric value passed to float() — the latter was uncaught before.
    except (KeyError, TypeError, ValueError, IndexError) as e:
        return {"output": {"error": f"Result formatting failed: {str(e)}"}}