FairRelay / brain /app /services /langgraph_workflow.py
MouleeswaranM's picture
Upload folder using huggingface_hub
fcf8749 verified
raw
history blame
8.9 kB
"""
LangGraph workflow definition for Fair Dispatch allocation.
Orchestrates all agents in a graph with conditional edges and checkpointing.
"""
import logging
import os
from typing import Dict, Any, Optional
from datetime import datetime

from langgraph.graph import StateGraph, END

from app.schemas.allocation_state import AllocationState
from app.services.langgraph_nodes import (
    ml_effort_node,
    route_planner_node,
    fairness_check_node,
    fairness_check_2_node,
    route_planner_reoptimize_node,
    select_final_proposal_node,
    driver_liaison_node,
    final_resolution_node,
    explainability_node,
    should_reoptimize,
    has_counter_decisions,
)
def create_allocation_graph(
    checkpointer: Optional[Any] = None,
    enable_gemini: bool = False,
) -> StateGraph:
    """
    Build and compile the allocation workflow graph.

    The graph always contains the nine core nodes wired as:
    ml_effort -> route_planner_1 -> fairness_agent_1, then either
    route_planner_2 -> fairness_agent_2 -> select_final ("reoptimize") or
    straight to select_final ("continue"), then driver_liaison, then either
    final_resolution -> explainability ("resolve") or explainability
    directly ("skip").

    The optional ``gemini_explain`` node is appended after ``explainability``
    only when all of the following hold: ``enable_gemini`` is true, the
    ``GOOGLE_API_KEY`` environment variable is set to a non-empty value, and
    ``app.services.gemini_explain_node`` can be imported. Otherwise
    ``explainability`` leads directly to ``END``.

    Args:
        checkpointer: Optional LangGraph checkpointer for state persistence.
            A falsy value compiles the graph without a checkpointer.
        enable_gemini: If True, try to add the Gemini explainability node.

    Returns:
        The object returned by ``StateGraph.compile()``, ready for invocation.
    """
    workflow = StateGraph(AllocationState)

    # ======================================================================
    # Nodes
    # ======================================================================
    # Phase 1: ML Effort Agent
    workflow.add_node("ml_effort", ml_effort_node)
    # Phase 2: Route Planner Agent (Proposal 1)
    workflow.add_node("route_planner_1", route_planner_node)
    # Phase 3: Fairness Manager Agent (Check 1) - renamed to avoid state key conflict
    workflow.add_node("fairness_agent_1", fairness_check_node)
    # Phase 3b: Route Planner Re-optimization (Proposal 2)
    workflow.add_node("route_planner_2", route_planner_reoptimize_node)
    # Phase 3c: Fairness Manager Agent (Check 2) - renamed to avoid state key conflict
    workflow.add_node("fairness_agent_2", fairness_check_2_node)
    # Phase 3d: Select Final Proposal
    workflow.add_node("select_final", select_final_proposal_node)
    # Phase 4: Driver Liaison Agent
    workflow.add_node("driver_liaison", driver_liaison_node)
    # Phase 5: Final Resolution Agent
    workflow.add_node("final_resolution", final_resolution_node)
    # Phase 6: Explainability Agent
    workflow.add_node("explainability", explainability_node)

    # Optional: Gemini Explainability Node.
    # The import is attempted exactly once; the flag below is the single
    # source of truth for whether the node exists when edges are wired.
    gemini_node_added = False
    if enable_gemini and os.getenv("GOOGLE_API_KEY"):
        try:
            from app.services.gemini_explain_node import gemini_explain_node
        except ImportError:
            # Best-effort feature: keep going without it, but leave a trace
            # so a requested-but-missing node is not silently dropped.
            logging.getLogger(__name__).warning(
                "Gemini explainability was requested but "
                "app.services.gemini_explain_node could not be imported; "
                "building the graph without the gemini_explain node."
            )
        else:
            workflow.add_node("gemini_explain", gemini_explain_node)
            gemini_node_added = True

    # ======================================================================
    # Edges
    # ======================================================================
    workflow.set_entry_point("ml_effort")

    # Linear flow: ML Effort -> Route Planner 1 -> Fairness Agent 1
    workflow.add_edge("ml_effort", "route_planner_1")
    workflow.add_edge("route_planner_1", "fairness_agent_1")

    # Conditional: Fairness Agent 1 -> Reoptimize or Select Final
    workflow.add_conditional_edges(
        "fairness_agent_1",
        should_reoptimize,
        {
            "reoptimize": "route_planner_2",
            "continue": "select_final",
        },
    )

    # Reoptimize path: Route Planner 2 -> Fairness Agent 2 -> Select Final
    workflow.add_edge("route_planner_2", "fairness_agent_2")
    workflow.add_edge("fairness_agent_2", "select_final")

    # Select Final -> Driver Liaison
    workflow.add_edge("select_final", "driver_liaison")

    # Conditional: Driver Liaison -> Final Resolution or Explainability
    workflow.add_conditional_edges(
        "driver_liaison",
        has_counter_decisions,
        {
            "resolve": "final_resolution",
            "skip": "explainability",
        },
    )

    # Final Resolution -> Explainability
    workflow.add_edge("final_resolution", "explainability")

    # Explainability -> Gemini -> END, or Explainability -> END
    if gemini_node_added:
        workflow.add_edge("explainability", "gemini_explain")
        workflow.add_edge("gemini_explain", END)
    else:
        workflow.add_edge("explainability", END)

    # ======================================================================
    # Compile
    # ======================================================================
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer)
    return workflow.compile()
# Global graph instance (lazy initialization).
# Starts as None; get_allocation_graph() stores the compiled graph here on
# first use, and clear_allocation_graph() resets it back to None.
_allocation_graph: Optional[Any] = None
def clear_allocation_graph() -> None:
    """
    Drop the cached allocation graph.

    Resets the module-level ``_allocation_graph`` to ``None`` so that the
    next ``get_allocation_graph()`` call has to build a new graph.
    """
    global _allocation_graph

    _allocation_graph = None
def get_allocation_graph(
    checkpointer: Optional[Any] = None,
    enable_gemini: Optional[bool] = None,
    force_recreate: bool = False,
) -> StateGraph:
    """
    Get or create the allocation graph singleton.

    NOTE: ``checkpointer`` and ``enable_gemini`` are only consulted when the
    graph is actually (re)built. If a cached graph already exists and
    ``force_recreate`` is False, the cached graph is returned as-is and both
    arguments are ignored; pass ``force_recreate=True`` to apply new settings.

    Args:
        checkpointer: Optional checkpointer for persistence.
        enable_gemini: Override for the Gemini setting. When None, it is
            taken from the ``ENABLE_GEMINI_EXPLAIN`` environment variable
            (enabled only when its value is "true", case-insensitive).
        force_recreate: If True, rebuild the graph even if one is cached.

    Returns:
        The compiled allocation graph (as returned by
        ``create_allocation_graph``).
    """
    global _allocation_graph

    # Cache hit: nothing below is needed, so skip the env lookup entirely.
    if _allocation_graph is not None and not force_recreate:
        return _allocation_graph

    if enable_gemini is None:
        enable_gemini = os.getenv("ENABLE_GEMINI_EXPLAIN", "false").lower() == "true"

    _allocation_graph = create_allocation_graph(
        checkpointer=checkpointer,
        enable_gemini=enable_gemini,
    )
    return _allocation_graph
async def invoke_allocation_workflow(
    request_dict: Dict[str, Any],
    config_used: Optional[Dict[str, Any]] = None,
    driver_models: Optional[list] = None,
    route_models: Optional[list] = None,
    route_dicts: Optional[list] = None,
    driver_contexts: Optional[Dict[str, Dict[str, Any]]] = None,
    recovery_targets: Optional[Dict[str, Optional[float]]] = None,
    allocation_run_id: Optional[str] = None,
    thread_id: Optional[str] = None,
    checkpointer: Optional[Any] = None,
) -> AllocationState:
    """
    Invoke the allocation workflow with the given inputs.

    This is the main entry point for running the LangGraph workflow. The
    graph is rebuilt on every call (``force_recreate=True``), which also
    replaces the module-level cached graph.

    Args:
        request_dict: AllocationRequest.dict()
        config_used: Active FairnessConfig snapshot (defaults to ``{}``)
        driver_models: List of driver model data (defaults to ``[]``)
        route_models: List of route model data (defaults to ``[]``)
        route_dicts: List of route dictionaries with packages (defaults to ``[]``)
        driver_contexts: Dict of driver contexts for liaison agent (defaults to ``{}``)
        recovery_targets: Recovery effort targets per driver (defaults to ``{}``)
        allocation_run_id: ID of the AllocationRun for persistence
        thread_id: Thread ID for checkpointing (defaults to allocation_run_id)
        checkpointer: Optional LangGraph checkpointer forwarded to the graph
            builder. When omitted the graph is compiled without one, exactly
            as before this parameter existed.

    Returns:
        Final AllocationState with all agent outputs
    """
    # Force recreate to pick up latest nodes. The checkpointer (if any) is
    # forwarded so the thread_id set below has a checkpointer to apply to.
    graph = get_allocation_graph(
        checkpointer=checkpointer,
        force_recreate=True,
    )

    # Build initial state; None inputs are normalised to empty containers.
    initial_state = AllocationState(
        request=request_dict,
        config_used=config_used or {},
        driver_models=driver_models or [],
        route_models=route_models or [],
        route_dicts=route_dicts or [],
        driver_contexts=driver_contexts or {},
        recovery_targets=recovery_targets or {},
        allocation_run_id=allocation_run_id,
        # NOTE(review): utcnow() yields a naive datetime; kept as-is because
        # code outside this block may compare it with other naive values.
        workflow_start=datetime.utcnow(),
    )

    # Prepare config for graph invocation: thread_id wins over the run id,
    # and no "configurable" entry is sent when neither is provided.
    config: Dict[str, Any] = {}
    effective_thread_id = thread_id or allocation_run_id
    if effective_thread_id:
        config["configurable"] = {"thread_id": effective_thread_id}

    # Invoke the graph; the awaited result is validated back into the model.
    final_state_dict = await graph.ainvoke(initial_state.model_dump(), config=config)
    return AllocationState.model_validate(final_state_dict)
def get_workflow_visualization() -> str:
    """
    Return the workflow as a Mermaid diagram for documentation.

    Returns:
        A string holding a fenced ``mermaid`` code block (``graph TD``) that
        starts and ends with a newline.
    """
    # The empty first and last entries produce the leading and trailing
    # newline once the lines are joined.
    diagram_lines = (
        "",
        "```mermaid",
        "graph TD",
        "A[Entry: ml_effort] --> B[route_planner_1]",
        "B --> C[fairness_agent_1]",
        "C --> D{should_reoptimize?}",
        "D -->|reoptimize| E[route_planner_2]",
        "D -->|continue| G[select_final]",
        "E --> F[fairness_agent_2]",
        "F --> G",
        "G --> H[driver_liaison]",
        "H --> I{has_counter_decisions?}",
        "I -->|resolve| J[final_resolution]",
        "I -->|skip| K[explainability]",
        "J --> K",
        "K --> L{gemini_enabled?}",
        "L -->|yes| M[gemini_explain]",
        "L -->|no| N[END]",
        "M --> N",
        "```",
        "",
    )
    return "\n".join(diagram_lines)