File size: 8,898 Bytes
fcf8749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""
LangGraph workflow definition for Fair Dispatch allocation.
Orchestrates all agents in a graph with conditional edges and checkpointing.
"""

import os
from typing import Dict, Any, Optional
from datetime import datetime

from langgraph.graph import StateGraph, END

from app.schemas.allocation_state import AllocationState
from app.services.langgraph_nodes import (
    ml_effort_node,
    route_planner_node,
    fairness_check_node,
    fairness_check_2_node,
    route_planner_reoptimize_node,
    select_final_proposal_node,
    driver_liaison_node,
    final_resolution_node,
    explainability_node,
    should_reoptimize,
    has_counter_decisions,
)


def create_allocation_graph(
    checkpointer: Optional[Any] = None,
    enable_gemini: bool = False,
) -> StateGraph:
    """
    Create and compile the allocation workflow graph.

    Flow: ml_effort -> route_planner_1 -> fairness_agent_1, then either
    route_planner_2 -> fairness_agent_2 (when ``should_reoptimize`` returns
    "reoptimize") or straight to select_final (when it returns "continue");
    then driver_liaison, final_resolution only when ``has_counter_decisions``
    returns "resolve", explainability, and optionally gemini_explain before END.

    Args:
        checkpointer: Optional LangGraph checkpointer for state persistence
        enable_gemini: If True, add the Gemini explainability node. The node is
            only added when GOOGLE_API_KEY is also set and
            ``app.services.gemini_explain_node`` is importable.

    Returns:
        Compiled graph (the result of ``StateGraph.compile``) ready for invocation
    """
    # Create the graph with AllocationState
    workflow = StateGraph(AllocationState)

    # ==========================================================================
    # Add Nodes
    # ==========================================================================

    # Phase 1: ML Effort Agent
    workflow.add_node("ml_effort", ml_effort_node)

    # Phase 2: Route Planner Agent (Proposal 1)
    workflow.add_node("route_planner_1", route_planner_node)

    # Phase 3: Fairness Manager Agent (Check 1) - renamed to avoid state key conflict
    workflow.add_node("fairness_agent_1", fairness_check_node)

    # Phase 3b: Route Planner Re-optimization (Proposal 2)
    workflow.add_node("route_planner_2", route_planner_reoptimize_node)

    # Phase 3c: Fairness Manager Agent (Check 2) - renamed to avoid state key conflict
    workflow.add_node("fairness_agent_2", fairness_check_2_node)

    # Phase 3d: Select Final Proposal
    workflow.add_node("select_final", select_final_proposal_node)

    # Phase 4: Driver Liaison Agent
    workflow.add_node("driver_liaison", driver_liaison_node)

    # Phase 5: Final Resolution Agent
    workflow.add_node("final_resolution", final_resolution_node)

    # Phase 6: Explainability Agent
    workflow.add_node("explainability", explainability_node)

    # Optional: Gemini Explainability Node. Resolved exactly once so that the
    # node registration here and the edge wiring below cannot disagree.
    gemini_node = _load_gemini_node(enable_gemini)
    if gemini_node is not None:
        workflow.add_node("gemini_explain", gemini_node)

    # ==========================================================================
    # Add Edges
    # ==========================================================================

    # Entry point
    workflow.set_entry_point("ml_effort")

    # Linear flow: ML Effort -> Route Planner 1
    workflow.add_edge("ml_effort", "route_planner_1")

    # Route Planner 1 -> Fairness Agent 1
    workflow.add_edge("route_planner_1", "fairness_agent_1")

    # Conditional: Fairness Agent 1 -> Reoptimize or Select Final
    workflow.add_conditional_edges(
        "fairness_agent_1",
        should_reoptimize,
        {
            "reoptimize": "route_planner_2",
            "continue": "select_final",
        }
    )

    # Reoptimize path: Route Planner 2 -> Fairness Agent 2 -> Select Final
    workflow.add_edge("route_planner_2", "fairness_agent_2")
    workflow.add_edge("fairness_agent_2", "select_final")

    # Select Final -> Driver Liaison
    workflow.add_edge("select_final", "driver_liaison")

    # Conditional: Driver Liaison -> Final Resolution or Explainability
    workflow.add_conditional_edges(
        "driver_liaison",
        has_counter_decisions,
        {
            "resolve": "final_resolution",
            "skip": "explainability",
        }
    )

    # Final Resolution -> Explainability
    workflow.add_edge("final_resolution", "explainability")

    # Explainability -> Gemini or END (only route through Gemini if its node
    # was actually registered above)
    if gemini_node is not None:
        workflow.add_edge("explainability", "gemini_explain")
        workflow.add_edge("gemini_explain", END)
    else:
        workflow.add_edge("explainability", END)

    # ==========================================================================
    # Compile
    # ==========================================================================

    # Explicit None check: a checkpointer object should never be skipped just
    # because it happens to evaluate as falsy.
    if checkpointer is not None:
        return workflow.compile(checkpointer=checkpointer)
    return workflow.compile()


def _load_gemini_node(enable_gemini: bool) -> Optional[Any]:
    """
    Return the Gemini explainability node callable, or None when it is unavailable.

    The node is returned only if ``enable_gemini`` is truthy, the
    GOOGLE_API_KEY environment variable is set (non-empty), and
    ``app.services.gemini_explain_node`` can be imported.
    """
    if not (enable_gemini and os.getenv("GOOGLE_API_KEY")):
        return None
    try:
        from app.services.gemini_explain_node import gemini_explain_node
    except ImportError:
        # Gemini support is optional: keep running without it, but make the
        # skipped node visible instead of silently ignoring the request.
        import logging

        logging.getLogger(__name__).warning(
            "Gemini explainability requested but gemini_explain_node could not "
            "be imported; continuing without the gemini_explain node."
        )
        return None
    return gemini_explain_node


# Global graph instance (lazy initialization)
_allocation_graph = None
# (checkpointer, enable_gemini) pair the cached graph was compiled with;
# None when no graph has been built here (or the cache was cleared).
_allocation_graph_key: Optional[tuple] = None


def clear_allocation_graph() -> None:
    """Clear the cached allocation graph (and its recorded configuration) to force recreation."""
    global _allocation_graph, _allocation_graph_key
    _allocation_graph = None
    _allocation_graph_key = None


def get_allocation_graph(
    checkpointer: Optional[Any] = None,
    enable_gemini: Optional[bool] = None,
    force_recreate: bool = False,
) -> StateGraph:
    """
    Get or create the cached allocation graph.

    The graph is (re)built when nothing is cached, when ``force_recreate`` is
    True, or when the requested ``checkpointer`` (compared by identity) or the
    resolved ``enable_gemini`` flag differs from the configuration the cached
    graph was built with. Different arguments therefore never silently return
    a graph compiled for another configuration.

    Args:
        checkpointer: Optional checkpointer for persistence
        enable_gemini: Override Gemini setting. When None, it is enabled iff the
            ENABLE_GEMINI_EXPLAIN env var equals "true" (case-insensitive).
        force_recreate: If True, recreate graph even if cached

    Returns:
        Compiled allocation graph, as returned by ``create_allocation_graph``
    """
    global _allocation_graph, _allocation_graph_key

    if enable_gemini is None:
        enable_gemini = os.getenv("ENABLE_GEMINI_EXPLAIN", "false").lower() == "true"

    # Only compare against a recorded configuration; a graph placed in the cache
    # by other code without a key is honored as-is.
    config_changed = _allocation_graph_key is not None and (
        _allocation_graph_key[0] is not checkpointer
        or _allocation_graph_key[1] != enable_gemini
    )

    if _allocation_graph is None or force_recreate or config_changed:
        _allocation_graph = create_allocation_graph(
            checkpointer=checkpointer,
            enable_gemini=enable_gemini,
        )
        # Holding the checkpointer object itself (not its id()) keeps it alive,
        # so the identity comparison above cannot be fooled by id reuse.
        _allocation_graph_key = (checkpointer, enable_gemini)

    return _allocation_graph


async def invoke_allocation_workflow(
    request_dict: Dict[str, Any],
    config_used: Optional[Dict[str, Any]] = None,
    driver_models: list = None,
    route_models: list = None,
    route_dicts: list = None,
    driver_contexts: Dict[str, Dict[str, Any]] = None,
    recovery_targets: Dict[str, Optional[float]] = None,
    allocation_run_id: Optional[str] = None,
    thread_id: Optional[str] = None,
) -> AllocationState:
    """
    Run the full allocation LangGraph workflow and return its final state.

    This is the main entry point for executing the workflow. Any optional
    collection argument left as None (or otherwise falsy) is replaced by an
    empty list/dict in the initial state.

    Args:
        request_dict: AllocationRequest.dict()
        config_used: Active FairnessConfig snapshot
        driver_models: List of driver model data
        route_models: List of route model data
        route_dicts: List of route dictionaries with packages
        driver_contexts: Dict of driver contexts for liaison agent
        recovery_targets: Recovery effort targets per driver
        allocation_run_id: ID of the AllocationRun for persistence
        thread_id: Thread ID for checkpointing; falls back to allocation_run_id.
            When neither is truthy, no "configurable" section is sent.

    Returns:
        AllocationState validated from the graph's final state dict
    """
    # The graph is rebuilt on every call so the latest node definitions are used.
    compiled_graph = get_allocation_graph(force_recreate=True)

    start_state = AllocationState(
        request=request_dict,
        config_used=config_used or {},
        driver_models=driver_models or [],
        route_models=route_models or [],
        route_dicts=route_dicts or [],
        driver_contexts=driver_contexts or {},
        recovery_targets=recovery_targets or {},
        allocation_run_id=allocation_run_id,
        workflow_start=datetime.utcnow(),
    )

    # Checkpoint thread: explicit thread_id wins, otherwise the run id.
    checkpoint_thread = thread_id or allocation_run_id
    run_config = (
        {"configurable": {"thread_id": checkpoint_thread}}
        if checkpoint_thread
        else {}
    )

    # ainvoke takes and returns plain dicts, so round-trip through the model.
    result_payload = await compiled_graph.ainvoke(
        start_state.model_dump(), config=run_config
    )
    return AllocationState.model_validate(result_payload)


def get_workflow_visualization() -> str:
    """
    Describe the allocation workflow as a Mermaid flowchart for documentation.

    The text is wrapped in a fenced ```mermaid block, begins with a newline
    and ends with a trailing newline.

    Returns:
        Mermaid diagram string
    """
    diagram_lines = [
        "",
        "```mermaid",
        "graph TD",
        "    A[Entry: ml_effort] --> B[route_planner_1]",
        "    B --> C[fairness_agent_1]",
        "    C --> D{should_reoptimize?}",
        "    D -->|reoptimize| E[route_planner_2]",
        "    D -->|continue| G[select_final]",
        "    E --> F[fairness_agent_2]",
        "    F --> G",
        "    G --> H[driver_liaison]",
        "    H --> I{has_counter_decisions?}",
        "    I -->|resolve| J[final_resolution]",
        "    I -->|skip| K[explainability]",
        "    J --> K",
        "    K --> L{gemini_enabled?}",
        "    L -->|yes| M[gemini_explain]",
        "    L -->|no| N[END]",
        "    M --> N",
        "```",
        "",
    ]
    # Empty first/last entries yield the leading and trailing newline.
    return "\n".join(diagram_lines)