File size: 3,733 Bytes
61411b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

import logging
from typing import Any, Dict

from langgraph.graph import END, START, StateGraph

from ai_business_automation_agent.agents.decision_agent import decision_route, run_decision_agent
from ai_business_automation_agent.agents.extraction_agent import run_extraction_agent
from ai_business_automation_agent.agents.reporting_agent import run_reporting_agent
from ai_business_automation_agent.agents.validation_agent import run_validation_agent
from ai_business_automation_agent.agents.vendor_verification_agent import run_vendor_verification_agent
from ai_business_automation_agent.llm import get_groq_llm
from ai_business_automation_agent.tools.erp_tool import simulate_erp_update
from ai_business_automation_agent.tools.web_search_tool import TavilyWebSearchTool
from ai_business_automation_agent.utils import append_agent_log
from ai_business_automation_agent.workflow.state_schema import InvoiceState

# Module-level logger named after this module, so its output can be filtered
# through the standard logging hierarchy.
logger = logging.getLogger(__name__)


def build_workflow(
    *,
    model: str = "llama-3.3-70b-versatile",
    temperature: float = 0.0,
):
    """Build and compile the invoice-processing LangGraph workflow.

    The graph runs extraction -> vendor verification -> validation ->
    decision. ``decision_route`` then sends "approved" invoices to the ERP
    update tool (followed by reporting), and "manual_review" / "rejected"
    invoices straight to reporting.

    Args:
        model: Groq model name passed to ``get_groq_llm``.
        temperature: Sampling temperature passed to ``get_groq_llm``.

    Returns:
        The compiled graph from ``StateGraph.compile()``; run it with
        ``.invoke(state)``.
    """
    llm = get_groq_llm(model=model, temperature=temperature)

    # Web search is optional: if the Tavily tool cannot be constructed, the
    # vendor-verification agent receives None instead of aborting the build.
    try:
        web_search = TavilyWebSearchTool()
    except Exception as e:
        logger.warning("Tavily web search disabled: %s", e)
        web_search = None

    def extraction_node(state: Dict[str, Any]) -> Dict[str, Any]:
        return run_extraction_agent(state, llm)

    def vendor_node(state: Dict[str, Any]) -> Dict[str, Any]:
        return run_vendor_verification_agent(state, llm, web_search)

    def validation_node(state: Dict[str, Any]) -> Dict[str, Any]:
        return run_validation_agent(state, llm)

    def decision_node(state: Dict[str, Any]) -> Dict[str, Any]:
        return run_decision_agent(state, llm)

    def erp_node(state: Dict[str, Any]) -> Dict[str, Any]:
        """Push the extracted invoice data to the (simulated) ERP.

        Returns a partial state update containing ``erp_update_status`` plus
        the agent-log entry produced by ``append_agent_log``.
        """
        extracted = state.get("extracted_data") or {}
        try:
            status = simulate_erp_update(extracted)
        except Exception as e:
            # Best-effort: an ERP failure is recorded in state (and logged with
            # its traceback) rather than aborting the run, so reporting still
            # happens.
            logger.exception("ERP update failed")
            err = {"status": "failed", "message": str(e)}
            updates = {"erp_update_status": err}
            updates.update(append_agent_log(state, agent="erp_tool", event="error", payload=err))
            return updates
        # Deliberately outside the try: an error while recording a successful
        # update must not be misreported as an ERP failure.
        updates = {"erp_update_status": status}
        updates.update(append_agent_log(state, agent="erp_tool", event="ok", payload=status))
        return updates

    def reporting_node(state: Dict[str, Any]) -> Dict[str, Any]:
        return run_reporting_agent(state, llm)

    graph = StateGraph(InvoiceState)

    graph.add_node("extraction_agent", extraction_node)
    graph.add_node("vendor_verification_agent", vendor_node)
    graph.add_node("validation_agent", validation_node)
    graph.add_node("decision_agent", decision_node)
    graph.add_node("erp_update_tool", erp_node)
    graph.add_node("reporting_agent", reporting_node)

    # Linear front half of the pipeline.
    graph.add_edge(START, "extraction_agent")
    graph.add_edge("extraction_agent", "vendor_verification_agent")
    graph.add_edge("vendor_verification_agent", "validation_agent")
    graph.add_edge("validation_agent", "decision_agent")

    # Only approved invoices touch the ERP; everything else goes to reporting.
    graph.add_conditional_edges(
        "decision_agent",
        decision_route,
        {
            "approved": "erp_update_tool",
            "manual_review": "reporting_agent",
            "rejected": "reporting_agent",
        },
    )

    graph.add_edge("erp_update_tool", "reporting_agent")
    graph.add_edge("reporting_agent", END)

    return graph.compile()


def run_workflow(email_content: str, *, app: Any = None) -> Dict[str, Any]:
    """Run the invoice workflow on a single email and return the final state.

    Args:
        email_content: Raw email text, stored as ``email_content`` in the
            initial state.
        app: Optional pre-compiled workflow, as returned by
            ``build_workflow()``. When omitted, a new workflow is built for
            this call. That re-creates the LLM client and the web-search tool
            and recompiles the graph, so pass a reused instance when
            processing many emails.

    Returns:
        The final state returned by ``app.invoke``.
    """
    if app is None:
        app = build_workflow()
    initial_state: InvoiceState = {"email_content": email_content, "agent_logs": []}
    return app.invoke(initial_state)