File size: 3,572 Bytes
bf3b3ab
 
 
 
 
 
ac0a470
2c430ef
67e33c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72f1955
 
da074ea
67e33c2
4d37ed9
 
 
 
 
b9fad49
67e33c2
 
 
f86c996
 
84c0505
2816a42
84c0505
 
67e33c2
 
84c0505
 
67e33c2
84c0505
 
 
 
 
 
 
 
 
 
 
 
be13701
 
 
e98ca85
be13701
 
 
 
 
 
 
84c0505
be13701
 
 
84c0505
be13701
 
 
 
 
 
 
 
84c0505
 
 
be13701
 
84c0505
 
 
be13701
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.prebuilt import ToolNode
from utils.state_utils import AgentState
from tools.financial_tools import time_value_tool



# # LLM instantiation

# llm = ChatOllama(model="qwen3:4b", temperature=0)
# llm_instantiated = llm.bind_tools(    
#     [time_value_tool],
#     tool_choice={"type": "function", "function": {"name": "time_value_tool"}}
# )

# def agent_node(state: AgentState):
#     response = llm_instantiated.invoke(state["messages"])
#     if not (hasattr(response, 'tool_calls') and response.tool_calls):
#         error_message = AIMessage(content="Error: Model failed to generate tool call.")
#         return {"messages": [error_message]}
#     return {"messages": [response]}

import os
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Hugging Face model used for the agent.
# Previously tried: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", "google/gemma-2b-it".
repo_id = "google/gemma-3n-e4b-it"

# Tool choice that forces the model to call time_value_tool on every turn.
_TIME_VALUE_TOOL_CHOICE = {"type": "function", "function": {"name": "time_value_tool"}}

# Remote endpoint. os.environ["HF_TOKEN"] raises KeyError at import time
# when the token is not set, so a missing credential fails fast.
llm_endpoint = HuggingFaceEndpoint(
    repo_id=repo_id,
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    max_new_tokens=1024,
    temperature=0.0,
)

# Chat-model wrapper around the endpoint, bound to the single financial tool.
llm = ChatHuggingFace(llm=llm_endpoint)
llm_instantiated = llm.bind_tools(
    [time_value_tool],
    tool_choice=_TIME_VALUE_TOOL_CHOICE,
)

def agent_node(state: AgentState):
    """Invoke the tool-bound LLM on the conversation so far.

    The model is bound with a forced ``time_value_tool`` choice, so a usable
    reply must carry at least one tool call. If the reply has no
    ``tool_calls`` attribute, or it is empty, the reply is discarded and
    replaced by an ``AIMessage`` holding a fixed error string.

    Returns:
        ``{"messages": [message]}`` where ``message`` is either the model reply
        or the error ``AIMessage``.
    """
    reply = llm_instantiated.invoke(state["messages"])
    # A missing attribute and an empty list are both treated as "no tool call".
    produced_call = bool(getattr(reply, "tool_calls", None))
    outgoing = reply if produced_call else AIMessage(content="Error: No tool call generated")
    return {"messages": [outgoing]}


# Prebuilt LangGraph node that executes the time_value_tool calls requested by the model.
tool_node = ToolNode(tools=[time_value_tool])

# Factor to output mapping
F_MAPPING = {
    "P/F": "PV", "P/A": "PV", "P/G": "PV",
    "F/P": "FV", "F/A": "FV", "F/G": "FV",
    "A/P": "Annual", "A/F": "Annual", "A/G": "Annual"
}

def format_output(state: AgentState):
    try:
        # The last message should be the ToolMessage (from the tool node)
        if not state["messages"] or not isinstance(state["messages"][-1], ToolMessage):
            return {"output": {"error": "No tool result found in the last message"}}
        
        tool_message = state["messages"][-1]
        # Parse the content of the tool message as JSON
        tool_result = json.loads(tool_message.content)
        
        # The second last message should be the AIMessage with the tool call
        if len(state["messages"]) < 2 or not isinstance(state["messages"][-2], AIMessage):
            return {"output": {"error": "No AI message (with tool call) found before the tool message"}}
        
        ai_message = state["messages"][-2]
        if not ai_message.tool_calls:
            return {"output": {"error": "The AI message does not contain tool calls"}}
        
        # We take the first tool call (since we forced one tool)
        tool_call = ai_message.tool_calls
        args = tool_call["args"]
        
        # Get the factor type from the args
        factor_type = args["F"]
        if factor_type not in F_MAPPING:
            return {"output": {"error": f"Unrecognized factor type: {factor_type}"}}
        
        result_key = F_MAPPING[factor_type]
        if result_key not in tool_result:
            return {"output": {"error": f"Expected key {result_key} not found in tool result"}}
        
        value = tool_result[result_key]
        return {"output": {result_key: round(float(value), 2)}}
        
    except (KeyError, TypeError, json.JSONDecodeError, IndexError) as e:
        return {"output": {"error": f"Result formatting failed: {str(e)}"}}