File size: 2,864 Bytes
2f16cc8
dea72cd
 
 
 
 
 
 
 
7e453aa
dea72cd
 
 
 
 
 
 
 
 
 
2f16cc8
 
dea72cd
 
 
 
 
29ee329
 
 
 
 
aad415a
 
 
dea72cd
29ee329
dea72cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e453aa
dea72cd
 
29ee329
dea72cd
 
 
 
 
 
 
 
2f16cc8
 
dea72cd
29ee329
 
aad415a
 
 
 
dea72cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import TypedDict, Dict, Any, List, Annotated, Optional
import time, uuid, os
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from agents.simple_tools import (
    generate_notes_full_pipeline_from_path,
    generate_balance_sheet,
    generate_pnl_statement,
    generate_cash_flow_statement,
    generate_llm_notes,
)

class FinancialAgentState(TypedDict):
    """State carried through the single-node workflows built by ``make_workflow``.

    ``run_workflow`` seeds every key before invoking a workflow. The workflow
    node reads ``file_path``, ``feedback_context`` and ``user_api_key`` and
    writes ``result``, ``status``, ``error``, ``start_time`` and ``end_time``.
    """

    # Message history; run_workflow seeds it with a single HumanMessage
    # describing the requested run. The "History" metadata is a plain string
    # label, not a callable. NOTE(review): LangGraph only treats callable
    # Annotated metadata as a reducer, so presumably this field is replaced
    # rather than appended to on update; confirm whether merging was intended.
    messages: Annotated[List[BaseMessage], "History"]
    # Path of the input file handed to the tool as its ``file_path`` argument.
    file_path: str
    # Raw return value of the tool (expected to be a dict); {} until the node runs.
    result: Dict[str, Any]
    # "success" or "error" once the node has run; "" initially.
    status: str
    # Wall-clock timestamps from time.time() taken around the tool call;
    # 0 until the node runs.
    start_time: float
    end_time: float
    # Error message taken from the tool result or from a caught exception;
    # "" when there is none.
    error: str
    # Optional API key passed through to the tool.
    user_api_key: Optional[str]
    # Optional feedback payload passed through to the tool.
    feedback_context: Optional[Dict[str, Any]]

def make_workflow(tool_func):
    """Compile a single-node LangGraph workflow that invokes ``tool_func``.

    The graph has one node, ``"run"``, wired from the entry point straight to
    ``END``. That node calls ``tool_func.invoke(...)`` with the state's
    ``file_path`` plus, when they are set (not ``None``), ``feedback_context``
    and ``user_api_key``. It then records the tool's result, a
    ``"success"``/``"error"`` status, any error message, and wall-clock
    start/end timestamps in the state.

    Args:
        tool_func: An object exposing ``invoke(params: dict)``, such as the
            tools imported from ``agents.simple_tools``. Its return value is
            expected to be a mapping with a ``"status"`` key and, optionally,
            an ``"error"`` key.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    from collections.abc import Mapping

    def node(state: FinancialAgentState) -> FinancialAgentState:
        state["start_time"] = time.time()
        try:
            tool_params: Dict[str, Any] = {"file_path": state["file_path"]}
            # Forward optional inputs only when they carry a value.
            # run_workflow() always seeds these keys with None, so a plain
            # membership test would pass explicit Nones and override any
            # non-None defaults the tool declares.
            for key in ("feedback_context", "user_api_key"):
                value = state.get(key)
                if value is not None:
                    tool_params[key] = value
            # Use .invoke() to avoid deprecation warning
            result = tool_func.invoke(tool_params)
            state["result"] = result
            if isinstance(result, Mapping):
                state["status"] = "success" if result.get("status") == "success" else "error"
                # Coerce a missing or None error to "" so the field stays a str.
                state["error"] = result.get("error") or ""
            else:
                # A non-mapping result cannot report a status; say so plainly
                # instead of surfacing an AttributeError from .get().
                state["status"] = "error"
                state["error"] = (
                    f"Tool returned unexpected result type: {type(result).__name__}"
                )
        except Exception as e:
            # Deliberate best-effort boundary: failures are reported through
            # the state rather than propagated to the graph caller.
            state["status"] = "error"
            # Some exceptions stringify to "" (e.g. a bare KeyError()); fall
            # back to the exception type so the error is never blank.
            state["error"] = str(e) or type(e).__name__
        state["end_time"] = time.time()
        return state

    wf = StateGraph(FinancialAgentState)
    wf.add_node("run", node)
    wf.set_entry_point("run")
    wf.add_edge("run", END)
    return wf.compile()

# Registry of compiled single-node workflows, keyed by the ``kind`` string
# accepted by run_workflow(). Each entry wraps one tool imported from
# agents.simple_tools; every graph is compiled once, at import time.
workflows = {
    kind: make_workflow(tool)
    for kind, tool in (
        ("notes", generate_notes_full_pipeline_from_path),
        ("pnl", generate_pnl_statement),
        ("bs", generate_balance_sheet),
        ("cf", generate_cash_flow_statement),
        ("notes-llm", generate_llm_notes),
    )
}

def run_workflow(file_path: str, kind: str, **kwargs) -> Dict[str, Any]:
    """Run the compiled workflow registered under ``kind`` on ``file_path``.

    Args:
        file_path: Path of the input file handed to the tool.
        kind: Key into ``workflows``: ``"notes"``, ``"pnl"``, ``"bs"``,
            ``"cf"`` or ``"notes-llm"``.
        **kwargs: Only ``feedback_context`` and ``user_api_key`` are read;
            they are copied into the initial state for the tool node.
            Any other keyword argument is ignored.

    Returns:
        The final workflow state returned by the graph, including
        ``result``, ``status``, ``error``, ``start_time`` and ``end_time``.

    Raises:
        KeyError: If ``kind`` is not a registered workflow.
    """
    # Resolve the workflow first so a bad ``kind`` fails before any work.
    try:
        workflow = workflows[kind]
    except KeyError:
        # Keep KeyError for existing callers, but name the valid choices.
        raise KeyError(
            f"Unknown workflow kind {kind!r}; expected one of {sorted(workflows)}"
        ) from None

    state = FinancialAgentState(
        messages=[HumanMessage(content=f"Run {kind} for {file_path}")],
        file_path=file_path,
        result={},
        status="",
        start_time=0,
        end_time=0,
        error="",
        # Omitted keywords fall back to None, the same as the original defaults.
        user_api_key=kwargs.get("user_api_key"),
        feedback_context=kwargs.get("feedback_context"),
    )
    return workflow.invoke(state)