| """ |
| LangGraph workflow definition for Fair Dispatch allocation. |
| Orchestrates all agents in a graph with conditional edges and checkpointing. |
| """ |
|
|
| import os |
| from typing import Dict, Any, Optional |
| from datetime import datetime |
|
|
| from langgraph.graph import StateGraph, END |
|
|
| from app.schemas.allocation_state import AllocationState |
| from app.services.langgraph_nodes import ( |
| ml_effort_node, |
| route_planner_node, |
| fairness_check_node, |
| fairness_check_2_node, |
| route_planner_reoptimize_node, |
| select_final_proposal_node, |
| driver_liaison_node, |
| final_resolution_node, |
| explainability_node, |
| should_reoptimize, |
| has_counter_decisions, |
| ) |
|
|
|
|
def _load_gemini_node(enable_gemini: bool) -> Optional[Any]:
    """Return the optional Gemini explain node callable, or None if unavailable.

    The node is only used when ``enable_gemini`` is truthy, the
    ``GOOGLE_API_KEY`` environment variable is set, and
    ``app.services.gemini_explain_node`` can be imported.
    """
    if not (enable_gemini and os.getenv("GOOGLE_API_KEY")):
        return None
    try:
        from app.services.gemini_explain_node import gemini_explain_node
    except ImportError:
        # Gemini explainability is best-effort: fall back to the base graph,
        # but leave a trace in the logs instead of failing silently.
        import logging

        logging.getLogger(__name__).warning(
            "Gemini explainability requested but gemini_explain_node could not "
            "be imported; building the allocation graph without it.",
            exc_info=True,
        )
        return None
    return gemini_explain_node


def create_allocation_graph(
    checkpointer: Optional[Any] = None,
    enable_gemini: bool = False,
) -> StateGraph:
    """
    Create and compile the allocation workflow graph.

    Flow: ml_effort -> route_planner_1 -> fairness_agent_1, then either
    straight to select_final or through route_planner_2 -> fairness_agent_2
    first; select_final -> driver_liaison, then optionally final_resolution;
    finally explainability, followed by gemini_explain when it is available.

    Args:
        checkpointer: Optional LangGraph checkpointer for state persistence.
            Only passed to ``compile`` when truthy.
        enable_gemini: If True (and ``GOOGLE_API_KEY`` is set and the Gemini
            node module imports), add the Gemini explainability node after
            ``explainability``.

    Returns:
        The compiled graph (result of ``StateGraph.compile``), ready for
        invocation.
    """
    workflow = StateGraph(AllocationState)

    # Resolve the optional node exactly once so that node registration and
    # edge wiring always agree on whether it is present.
    gemini_node = _load_gemini_node(enable_gemini)

    # --- Nodes -------------------------------------------------------------
    workflow.add_node("ml_effort", ml_effort_node)
    workflow.add_node("route_planner_1", route_planner_node)
    workflow.add_node("fairness_agent_1", fairness_check_node)
    workflow.add_node("route_planner_2", route_planner_reoptimize_node)
    workflow.add_node("fairness_agent_2", fairness_check_2_node)
    workflow.add_node("select_final", select_final_proposal_node)
    workflow.add_node("driver_liaison", driver_liaison_node)
    workflow.add_node("final_resolution", final_resolution_node)
    workflow.add_node("explainability", explainability_node)
    if gemini_node is not None:
        workflow.add_node("gemini_explain", gemini_node)

    # --- Edges -------------------------------------------------------------
    workflow.set_entry_point("ml_effort")
    workflow.add_edge("ml_effort", "route_planner_1")
    workflow.add_edge("route_planner_1", "fairness_agent_1")

    # should_reoptimize's return value selects the next node via this mapping.
    workflow.add_conditional_edges(
        "fairness_agent_1",
        should_reoptimize,
        {
            "reoptimize": "route_planner_2",
            "continue": "select_final",
        },
    )

    # The re-optimisation branch rejoins the main path at select_final.
    workflow.add_edge("route_planner_2", "fairness_agent_2")
    workflow.add_edge("fairness_agent_2", "select_final")

    workflow.add_edge("select_final", "driver_liaison")

    # has_counter_decisions's return value selects the next node.
    workflow.add_conditional_edges(
        "driver_liaison",
        has_counter_decisions,
        {
            "resolve": "final_resolution",
            "skip": "explainability",
        },
    )

    workflow.add_edge("final_resolution", "explainability")

    if gemini_node is not None:
        workflow.add_edge("explainability", "gemini_explain")
        workflow.add_edge("gemini_explain", END)
    else:
        workflow.add_edge("explainability", END)

    # --- Compile -----------------------------------------------------------
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer)
    return workflow.compile()
|
|
|
|
| |
# Process-wide cache of the compiled allocation graph.
_allocation_graph = None
# (checkpointer, enable_gemini) that the cached graph was built with. None when
# the cache is empty or was populated without going through
# get_allocation_graph.
_allocation_graph_settings: Optional[tuple] = None


def clear_allocation_graph() -> None:
    """Clear the cached allocation graph (and its recorded settings) to force recreation."""
    global _allocation_graph, _allocation_graph_settings
    _allocation_graph = None
    _allocation_graph_settings = None


def get_allocation_graph(
    checkpointer: Optional[Any] = None,
    enable_gemini: Optional[bool] = None,
    force_recreate: bool = False,
) -> StateGraph:
    """
    Get or create the allocation graph singleton.

    The cached graph is reused unless ``force_recreate`` is set or the caller
    explicitly requests settings that the cached graph was not built with: a
    different ``checkpointer`` object, or an explicit ``enable_gemini`` value
    that differs. Passing ``None`` for either argument means "no preference",
    so argument-less calls always receive the cached graph.

    Args:
        checkpointer: Optional checkpointer for persistence.
        enable_gemini: Override the Gemini setting. When None and a graph has to
            be built, the value is read from the ``ENABLE_GEMINI_EXPLAIN`` env
            var (case-insensitive "true").
        force_recreate: If True, recreate the graph even if one is cached.

    Returns:
        Compiled allocation graph, as produced by ``create_allocation_graph``.
    """
    global _allocation_graph, _allocation_graph_settings

    settings = _allocation_graph_settings
    # Only an explicit, conflicting request invalidates the cache; previously
    # such arguments were silently ignored once a graph was cached.
    explicit_mismatch = settings is not None and (
        (checkpointer is not None and checkpointer is not settings[0])
        or (enable_gemini is not None and enable_gemini != settings[1])
    )

    if _allocation_graph is not None and not force_recreate and not explicit_mismatch:
        return _allocation_graph

    if enable_gemini is None:
        enable_gemini = os.getenv("ENABLE_GEMINI_EXPLAIN", "false").lower() == "true"

    _allocation_graph = create_allocation_graph(
        checkpointer=checkpointer,
        enable_gemini=enable_gemini,
    )
    _allocation_graph_settings = (checkpointer, enable_gemini)
    return _allocation_graph
|
|
|
|
async def invoke_allocation_workflow(
    request_dict: Dict[str, Any],
    config_used: Optional[Dict[str, Any]] = None,
    driver_models: list = None,
    route_models: list = None,
    route_dicts: list = None,
    driver_contexts: Dict[str, Dict[str, Any]] = None,
    recovery_targets: Dict[str, Optional[float]] = None,
    allocation_run_id: Optional[str] = None,
    thread_id: Optional[str] = None,
) -> AllocationState:
    """
    Run the allocation workflow end to end and return its final state.

    This is the main entry point for running the LangGraph workflow. The graph
    is rebuilt on every call (``force_recreate=True``), which also replaces the
    module-level cached graph.

    Args:
        request_dict: AllocationRequest.dict()
        config_used: Active FairnessConfig snapshot (empty dict when None)
        driver_models: List of driver model data (empty list when None)
        route_models: List of route model data (empty list when None)
        route_dicts: List of route dictionaries with packages (empty list when None)
        driver_contexts: Dict of driver contexts for liaison agent (empty dict when None)
        recovery_targets: Recovery effort targets per driver (empty dict when None)
        allocation_run_id: ID of the AllocationRun for persistence
        thread_id: Thread ID for checkpointing (defaults to allocation_run_id)

    Returns:
        Final AllocationState, validated from the dict the graph returns.
    """
    compiled_graph = get_allocation_graph(force_recreate=True)

    # NOTE(review): datetime.utcnow() yields a naive UTC timestamp and is
    # deprecated since Python 3.12; moving to an aware datetime needs checking
    # against whatever consumes workflow_start.
    seed_state = AllocationState(
        request=request_dict,
        config_used=config_used or {},
        driver_models=driver_models or [],
        route_models=route_models or [],
        route_dicts=route_dicts or [],
        driver_contexts=driver_contexts or {},
        recovery_targets=recovery_targets or {},
        allocation_run_id=allocation_run_id,
        workflow_start=datetime.utcnow(),
    )

    # NOTE(review): no checkpointer is supplied when building the graph above,
    # so this thread_id presumably has no persistence effect — confirm.
    effective_thread_id = thread_id or allocation_run_id
    run_config: Dict[str, Any] = (
        {"configurable": {"thread_id": effective_thread_id}}
        if effective_thread_id
        else {}
    )

    # The graph operates on plain dicts; convert in and back out at the edges.
    raw_result = await compiled_graph.ainvoke(seed_state.model_dump(), config=run_config)
    return AllocationState.model_validate(raw_result)
|
|
|
|
def get_workflow_visualization() -> str:
    """
    Get a Mermaid diagram of the workflow for documentation.

    The diagram is a hand-written string literal; it is not generated from
    ``create_allocation_graph``, so it must be updated by hand whenever nodes
    or edges change there.

    Returns:
        Mermaid diagram string wrapped in a ```mermaid fenced code block.
    """
    return """
```mermaid
graph TD
A[Entry: ml_effort] --> B[route_planner_1]
B --> C[fairness_agent_1]
C --> D{should_reoptimize?}
D -->|reoptimize| E[route_planner_2]
D -->|continue| G[select_final]
E --> F[fairness_agent_2]
F --> G
G --> H[driver_liaison]
H --> I{has_counter_decisions?}
I -->|resolve| J[final_resolution]
I -->|skip| K[explainability]
J --> K
K --> L{gemini_enabled?}
L -->|yes| M[gemini_explain]
L -->|no| N[END]
M --> N
```
"""
|
|