FairRelay / brain /app /services /langgraph_nodes.py
MouleeswaranM's picture
Upload folder using huggingface_hub
fcf8749 verified
raw
history blame
36.6 kB
"""
LangGraph node wrappers for Fair Dispatch agents.
Each node wraps an existing agent with minimal changes, preserving the original logic.
"""
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple
from uuid import UUID
import asyncio
from app.schemas.allocation_state import AllocationState
from app.schemas.agent_schemas import (
FairnessThresholds,
DriverAssignmentProposal,
DriverContext,
)
from app.services.ml_effort_agent import MLEffortAgent
from app.services.route_planner_agent import RoutePlannerAgent
from app.services.fairness_manager_agent import FairnessManagerAgent
from app.services.driver_liaison_agent import DriverLiaisonAgent
from app.services.final_resolution import FinalResolutionAgent
from app.services.explainability import ExplainabilityAgent
from app.schemas.explainability import DriverExplanationInput
from app.services.fairness import calculate_fairness_score
from app.core.events import agent_event_bus, make_agent_event
class ModelWrapper:
    """
    Expose the keys of a dict as attributes for agent compatibility.

    Agents wrapped by the nodes in this module read attributes (``obj.field``),
    while the graph state stores plain dicts. Reading an attribute that is not
    a key of the wrapped dict yields ``None`` instead of raising.
    """

    def __init__(self, data: Dict[str, Any]):
        self._data = data

    def __getattr__(self, name: str) -> Any:
        """
        Return ``data.get(name)`` for attributes not found by normal lookup.

        Raises:
            AttributeError: for ``_data`` itself and for dunder names.
        """
        # __getattr__ only runs when normal lookup fails. If "_data" ends up
        # here the instance has not been initialised (copy/pickle create the
        # object without calling __init__), and reading self._data again would
        # recurse without bound. Dunder probes such as __setstate__ or
        # __deepcopy__ must also fail normally so that protocol detection via
        # hasattr()/getattr(..., None) keeps working.
        if name == "_data" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return self._data.get(name)

    @property
    def is_ev(self) -> bool:
        """True when the wrapped dict's ``vehicle_type`` equals ``"EV"``."""
        return self._data.get("vehicle_type") == "EV"
# Strong references to in-flight publish tasks. The event loop itself only
# keeps weak references to tasks, so a task nobody else references can be
# garbage-collected before it finishes.
_PENDING_EVENT_TASKS: set = set()


def _publish_event_sync(
    allocation_run_id: Optional[str],
    agent_name: str,
    step_type: str,
    state: str,
    payload: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Publish an agent event from synchronous code (fire-and-forget).
    Used by LangGraph nodes which are synchronous functions.

    Args:
        allocation_run_id: Run the event belongs to. When falsy, nothing is
            published and the function returns immediately.
        agent_name: Name of the agent emitting the event.
        step_type: Step identifier passed through to the event.
        state: Step state passed through to the event (the nodes in this
            module use "STARTED" and "COMPLETED").
        payload: Optional extra data attached to the event.

    If an event loop is running in the current thread, the publish is scheduled
    on it as a task and this function returns without waiting. Otherwise the
    publish is run to completion with ``asyncio.run``.
    """
    if not allocation_run_id:
        return

    event = make_agent_event(
        allocation_run_id=allocation_run_id,
        agent_name=agent_name,
        step_type=step_type,
        state=state,
        payload=payload,
    )

    # Keep the try body minimal: only get_running_loop() signals "no loop".
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: drive the publish on a temporary one.
        asyncio.run(agent_event_bus.publish(event))
        return

    task = loop.create_task(agent_event_bus.publish(event))
    # Hold the task until it completes, then drop the reference.
    _PENDING_EVENT_TASKS.add(task)
    task.add_done_callback(_PENDING_EVENT_TASKS.discard)
def _create_decision_log(
    agent_name: str,
    step_type: str,
    input_snapshot: Dict[str, Any],
    output_snapshot: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Create a decision log entry compatible with DecisionLog model.

    Args:
        agent_name: Name of the agent that produced the decision.
        step_type: Identifier of the step being logged.
        input_snapshot: Summary of what the agent received.
        output_snapshot: Summary of what the agent produced.

    Returns:
        Dict with the keys ``timestamp`` (naive UTC time as an ISO-8601
        string), ``agent_name``, ``step_type``, ``input_snapshot`` and
        ``output_snapshot``. The snapshots are stored by reference.
    """
    # Local import: the file-level import only brings in `datetime`.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # timestamp and drop tzinfo so the serialized value stays the same naive
    # ISO string (no "+00:00" suffix) that this function produced before.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    return {
        "timestamp": timestamp,
        "agent_name": agent_name,
        "step_type": step_type,
        "input_snapshot": input_snapshot,
        "output_snapshot": output_snapshot,
    }
# =============================================================================
# Node 1: ML Effort Agent
# =============================================================================
def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #1: ML Effort Agent.

    Wraps every driver and route dict from the state in a ModelWrapper, passes
    them to ``MLEffortAgent.compute_effort_matrix`` together with the EV
    settings taken from ``state.config_used``, and serializes the result.

    Returns a partial state update with:
        effort_matrix - dict holding matrix, driver_ids, route_ids, breakdown,
                        stats and infeasible_pairs
        decision_logs - the existing logs plus one MATRIX_GENERATION entry
    """
    run_id = state.allocation_run_id

    _publish_event_sync(
        run_id,
        "ML_EFFORT",
        "MATRIX_GENERATION",
        "STARTED",
        {
            "num_drivers": len(state.driver_models),
            "num_routes": len(state.route_models),
        },
    )

    agent = MLEffortAgent()

    # A missing or empty config means every EV setting uses its default.
    settings = state.config_used or {}
    ev_config = {
        "safety_margin_pct": settings.get("ev_safety_margin_pct", 10.0),
        "charging_penalty_weight": settings.get("ev_charging_penalty_weight", 0.3),
    }

    # The agent reads attributes, the state stores dicts.
    wrapped_drivers = [ModelWrapper(item) for item in state.driver_models]
    wrapped_routes = [ModelWrapper(item) for item in state.route_models]

    result = agent.compute_effort_matrix(
        drivers=wrapped_drivers,
        routes=wrapped_routes,
        ev_config=ev_config,
    )

    # Breakdown values are dumped to plain data when they offer model_dump().
    breakdown: Dict[Any, Any] = {}
    for pair_key, detail in result.breakdown.items():
        if hasattr(detail, "model_dump"):
            breakdown[pair_key] = detail.model_dump()
        else:
            breakdown[pair_key] = detail

    infeasible = result.infeasible_pairs
    serialized = {
        "matrix": result.matrix,
        "driver_ids": result.driver_ids,
        "route_ids": result.route_ids,
        "breakdown": breakdown,
        "stats": result.stats,
        "infeasible_pairs": list(infeasible) if infeasible else [],
    }

    # Input snapshot first, then output snapshot, as in the logging contract.
    snapshot_in = agent.get_input_snapshot(wrapped_drivers, wrapped_routes)
    snapshot_out = dict(agent.get_output_snapshot(result))
    snapshot_out["num_infeasible_ev_pairs"] = len(infeasible) if infeasible else 0
    entry = _create_decision_log(
        agent_name="ML_EFFORT",
        step_type="MATRIX_GENERATION",
        input_snapshot=snapshot_in,
        output_snapshot=snapshot_out,
    )

    stats = result.stats
    _publish_event_sync(
        run_id,
        "ML_EFFORT",
        "MATRIX_GENERATION",
        "COMPLETED",
        {
            "min_effort": stats.get("min", 0),
            "max_effort": stats.get("max", 0),
            "avg_effort": stats.get("avg", 0),
        },
    )

    return {
        "effort_matrix": serialized,
        "decision_logs": state.decision_logs + [entry],
    }
# =============================================================================
# Node 2: Route Planner Agent (Proposal 1)
# =============================================================================
def route_planner_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #2: Route Planner Agent - Proposal 1.

    Rebuilds an EffortMatrixResult from the serialized ``state.effort_matrix``
    and asks ``RoutePlannerAgent.plan`` for proposal number 1.

    Returns a partial state update with:
        route_proposal_1 - serialized plan (allocation, total_effort,
                           avg_effort, solver_status, proposal_number,
                           per_driver_effort)
        decision_logs    - the existing logs plus one PROPOSAL_1 entry
    """
    run_id = state.allocation_run_id

    _publish_event_sync(
        run_id,
        "ROUTE_PLANNER",
        "PROPOSAL_1",
        "STARTED",
        {
            "num_drivers": len(state.driver_models),
            "num_routes": len(state.route_models),
        },
    )

    planner = RoutePlannerAgent()

    from app.schemas.agent_schemas import EffortMatrixResult, EffortBreakdown

    serialized = state.effort_matrix
    matrix = serialized["matrix"]

    # Prefer the stats stored with the matrix; otherwise derive them from the
    # finite cells (infinite cells are skipped).
    stats = serialized.get("stats")
    if not stats:
        finite = [cell for row in matrix for cell in row if cell < float("inf")]
        if finite:
            stats = {
                "min": min(finite),
                "max": max(finite),
                "avg": sum(finite) / len(finite),
            }
        else:
            stats = {"min": 0.0, "max": 0.0, "avg": 0.0}

    effort_result = EffortMatrixResult(
        matrix=matrix,
        driver_ids=state.effort_matrix["driver_ids"],
        route_ids=state.effort_matrix["route_ids"],
        breakdown={},  # the per-pair breakdown is not passed to the planner
        stats=stats,
        infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
    )

    settings = state.config_used or {}
    recovery_weight = settings.get("recovery_penalty_weight", 3.0)

    # The agent reads attributes, the state stores dicts.
    wrapped_drivers = [ModelWrapper(item) for item in state.driver_models]
    wrapped_routes = [ModelWrapper(item) for item in state.route_models]

    plan = planner.plan(
        effort_result=effort_result,
        drivers=wrapped_drivers,
        routes=wrapped_routes,
        recovery_targets=state.recovery_targets or {},
        recovery_penalty_weight=recovery_weight,
        proposal_number=1,
    )

    allocation = []
    for item in plan.allocation:
        allocation.append(item.model_dump() if hasattr(item, "model_dump") else item)
    serialized_plan = {
        "allocation": allocation,
        "total_effort": plan.total_effort,
        "avg_effort": plan.avg_effort,
        "solver_status": plan.solver_status,
        "proposal_number": plan.proposal_number,
        "per_driver_effort": plan.per_driver_effort,
    }

    entry = _create_decision_log(
        agent_name="ROUTE_PLANNER",
        step_type="PROPOSAL_1",
        input_snapshot=planner.get_input_snapshot(effort_result),
        output_snapshot=planner.get_output_snapshot(plan),
    )

    _publish_event_sync(
        run_id,
        "ROUTE_PLANNER",
        "PROPOSAL_1",
        "COMPLETED",
        {
            "total_effort": plan.total_effort,
            "num_assignments": len(plan.allocation),
            "solver_status": plan.solver_status,
        },
    )

    return {
        "route_proposal_1": serialized_plan,
        "decision_logs": state.decision_logs + [entry],
    }
# =============================================================================
# Node 3: Fairness Check Agent
# =============================================================================
def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #3: Fairness Manager Agent.

    Checks proposal 2 when ``state.route_proposal_2`` is set, otherwise
    proposal 1, by passing a rebuilt RoutePlanResult to
    ``FairnessManagerAgent.check``. Thresholds come from ``state.config_used``
    with built-in defaults.

    Returns a partial state update with:
        decision_logs                       - existing logs plus one entry
        fairness_check_1 / fairness_check_2 - serialized result, stored under
                                              the key of the checked proposal
    """
    run_id = state.allocation_run_id
    proposal_number = 2 if state.route_proposal_2 else 1
    event_step = f"FAIRNESS_CHECK_{proposal_number}"

    _publish_event_sync(
        run_id,
        "FAIRNESS_MANAGER",
        event_step,
        "STARTED",
        {"proposal_number": proposal_number},
    )

    # A missing or empty config means every threshold uses its default.
    settings = state.config_used or {}
    limits = FairnessThresholds(
        gini_threshold=settings.get("gini_threshold", 0.33),
        stddev_threshold=settings.get("stddev_threshold", 25.0),
        max_gap_threshold=settings.get("max_gap_threshold", 25.0),
    )
    checker = FairnessManagerAgent(thresholds=limits)

    from app.schemas.agent_schemas import RoutePlanResult, AllocationItem

    candidate = state.route_proposal_2 or state.route_proposal_1

    items = [AllocationItem(**raw) for raw in candidate["allocation"]]
    total = candidate["total_effort"]
    # Average used only when the proposal does not carry "avg_effort".
    if candidate["allocation"]:
        derived_avg = candidate["total_effort"] / len(candidate["allocation"])
    else:
        derived_avg = 0.0
    plan_result = RoutePlanResult(
        allocation=items,
        total_effort=total,
        avg_effort=candidate.get("avg_effort", derived_avg),
        solver_status=candidate.get("solver_status", "OPTIMAL"),
        proposal_number=proposal_number,
        per_driver_effort=candidate["per_driver_effort"],
    )

    outcome = checker.check(plan_result, proposal_number=proposal_number)

    # Metrics are dumped via model_dump() when available, else field by field.
    measured = outcome.metrics
    if hasattr(measured, "model_dump"):
        metrics_payload = measured.model_dump()
    else:
        metrics_payload = {
            "avg_effort": measured.avg_effort,
            "std_dev": measured.std_dev,
            "gini_index": measured.gini_index,
            "max_effort": measured.max_effort,
            "min_effort": measured.min_effort,
            "max_gap": measured.max_gap,
        }

    advice = outcome.recommendations
    if advice and hasattr(advice, "model_dump"):
        advice_payload = advice.model_dump()
    else:
        advice_payload = None

    fairness_payload = {
        "status": outcome.status,
        "proposal_number": outcome.proposal_number,
        "metrics": metrics_payload,
        "recommendations": advice_payload,
    }

    entry = _create_decision_log(
        agent_name="FAIRNESS_MANAGER",
        step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
        input_snapshot=checker.get_input_snapshot(plan_result),
        output_snapshot=checker.get_output_snapshot(outcome),
    )

    _publish_event_sync(
        run_id,
        "FAIRNESS_MANAGER",
        event_step,
        "COMPLETED",
        {
            "status": outcome.status,
            "gini_index": fairness_payload["metrics"]["gini_index"],
            "std_dev": fairness_payload["metrics"]["std_dev"],
        },
    )

    # The result is stored under the key matching the proposal that was checked.
    result_key = "fairness_check_1" if proposal_number == 1 else "fairness_check_2"
    return {
        "decision_logs": state.decision_logs + [entry],
        result_key: fairness_payload,
    }
def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node for the second fairness check (dedicated wrapper).

    Always evaluates ``state.route_proposal_2`` with
    ``FairnessManagerAgent.check``. It exists as its own function to avoid
    LangGraph state key conflicts when one function backs several nodes.

    Returns a partial state update with:
        fairness_check_2 - serialized result of the check
        decision_logs    - existing logs plus one entry
    """
    run_id = state.allocation_run_id
    proposal_number = 2  # this node never checks proposal 1
    event_step = f"FAIRNESS_CHECK_{proposal_number}"

    _publish_event_sync(
        run_id,
        "FAIRNESS_MANAGER",
        event_step,
        "STARTED",
        {"proposal_number": proposal_number},
    )

    # A missing or empty config means every threshold uses its default.
    settings = state.config_used or {}
    limits = FairnessThresholds(
        gini_threshold=settings.get("gini_threshold", 0.33),
        stddev_threshold=settings.get("stddev_threshold", 25.0),
        max_gap_threshold=settings.get("max_gap_threshold", 25.0),
    )
    checker = FairnessManagerAgent(thresholds=limits)

    from app.schemas.agent_schemas import RoutePlanResult, AllocationItem

    candidate = state.route_proposal_2

    items = [AllocationItem(**raw) for raw in candidate["allocation"]]
    total = candidate["total_effort"]
    # Average used only when the proposal does not carry "avg_effort".
    if candidate["allocation"]:
        derived_avg = candidate["total_effort"] / len(candidate["allocation"])
    else:
        derived_avg = 0.0
    plan_result = RoutePlanResult(
        allocation=items,
        total_effort=total,
        avg_effort=candidate.get("avg_effort", derived_avg),
        solver_status=candidate.get("solver_status", "OPTIMAL"),
        proposal_number=proposal_number,
        per_driver_effort=candidate["per_driver_effort"],
    )

    outcome = checker.check(plan_result, proposal_number=proposal_number)

    # Metrics are dumped via model_dump() when available, else field by field.
    measured = outcome.metrics
    if hasattr(measured, "model_dump"):
        metrics_payload = measured.model_dump()
    else:
        metrics_payload = {
            "avg_effort": measured.avg_effort,
            "std_dev": measured.std_dev,
            "gini_index": measured.gini_index,
            "max_effort": measured.max_effort,
            "min_effort": measured.min_effort,
            "max_gap": measured.max_gap,
        }

    advice = outcome.recommendations
    if advice and hasattr(advice, "model_dump"):
        advice_payload = advice.model_dump()
    else:
        advice_payload = None

    fairness_payload = {
        "status": outcome.status,
        "proposal_number": outcome.proposal_number,
        "metrics": metrics_payload,
        "recommendations": advice_payload,
    }

    entry = _create_decision_log(
        agent_name="FAIRNESS_MANAGER",
        step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
        input_snapshot=checker.get_input_snapshot(plan_result),
        output_snapshot=checker.get_output_snapshot(outcome),
    )

    _publish_event_sync(
        run_id,
        "FAIRNESS_MANAGER",
        event_step,
        "COMPLETED",
        {
            "status": outcome.status,
            "gini_index": fairness_payload["metrics"]["gini_index"],
            "std_dev": fairness_payload["metrics"]["std_dev"],
        },
    )

    return {
        "fairness_check_2": fairness_payload,
        "decision_logs": state.decision_logs + [entry],
    }
# =============================================================================
# Node 4: Route Planner Agent (Proposal 2 - with fairness penalties)
# =============================================================================
def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #4: Route Planner Agent - Proposal 2 (re-optimization).

    Runs ``RoutePlannerAgent.plan`` a second time. When the first fairness
    check stored recommendations, they are turned into fairness penalties
    (using proposal 1's per-driver effort) and passed to the planner;
    otherwise the penalties are an empty dict.

    Returns a partial state update with:
        route_proposal_2 - serialized plan for proposal number 2
        decision_logs    - existing logs plus one PROPOSAL_2 entry
    """
    run_id = state.allocation_run_id

    _publish_event_sync(
        run_id,
        "ROUTE_PLANNER",
        "PROPOSAL_2",
        "STARTED",
        {"reason": "fairness_reoptimization"},
    )

    planner = RoutePlannerAgent()

    from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations

    serialized = state.effort_matrix
    matrix = serialized["matrix"]

    # Prefer the stats stored with the matrix; otherwise derive them from the
    # finite cells (infinite cells are skipped).
    stats = serialized.get("stats")
    if not stats:
        finite = [cell for row in matrix for cell in row if cell < float("inf")]
        if finite:
            stats = {
                "min": min(finite),
                "max": max(finite),
                "avg": sum(finite) / len(finite),
            }
        else:
            stats = {"min": 0.0, "max": 0.0, "avg": 0.0}

    effort_result = EffortMatrixResult(
        matrix=matrix,
        driver_ids=state.effort_matrix["driver_ids"],
        route_ids=state.effort_matrix["route_ids"],
        breakdown={},
        stats=stats,
        infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
    )

    # Penalties stay empty unless check 1 produced recommendations.
    fairness_penalties = {}
    stored_advice = state.fairness_check_1.get("recommendations")
    if stored_advice:
        fairness_penalties = planner.build_penalties_from_recommendations(
            FairnessRecommendations(**stored_advice),
            state.route_proposal_1["per_driver_effort"],
        )

    settings = state.config_used or {}
    recovery_weight = settings.get("recovery_penalty_weight", 3.0)

    # The agent reads attributes, the state stores dicts.
    wrapped_drivers = [ModelWrapper(item) for item in state.driver_models]
    wrapped_routes = [ModelWrapper(item) for item in state.route_models]

    plan = planner.plan(
        effort_result=effort_result,
        drivers=wrapped_drivers,
        routes=wrapped_routes,
        fairness_penalties=fairness_penalties,
        recovery_targets=state.recovery_targets or {},
        recovery_penalty_weight=recovery_weight,
        proposal_number=2,
    )

    allocation = []
    for item in plan.allocation:
        allocation.append(item.model_dump() if hasattr(item, "model_dump") else item)
    serialized_plan = {
        "allocation": allocation,
        "total_effort": plan.total_effort,
        "avg_effort": plan.avg_effort,
        "solver_status": plan.solver_status,
        "proposal_number": plan.proposal_number,
        "per_driver_effort": plan.per_driver_effort,
    }

    entry = _create_decision_log(
        agent_name="ROUTE_PLANNER",
        step_type="PROPOSAL_2",
        input_snapshot=planner.get_input_snapshot(effort_result, fairness_penalties),
        output_snapshot=planner.get_output_snapshot(plan),
    )

    _publish_event_sync(
        run_id,
        "ROUTE_PLANNER",
        "PROPOSAL_2",
        "COMPLETED",
        {
            "total_effort": plan.total_effort,
            "num_assignments": len(plan.allocation),
            "solver_status": plan.solver_status,
        },
    )

    return {
        "route_proposal_2": serialized_plan,
        "decision_logs": state.decision_logs + [entry],
    }
# =============================================================================
# Node 5: Select Final Proposal
# =============================================================================
def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
    """
    Choose which proposal continues through the rest of the graph.

    Proposal 1 and its fairness check are the default. Proposal 2 replaces
    them when both it and its fairness check exist and either its gini_index
    is not higher than proposal 1's or its max_gap is strictly lower.

    Returns a partial state update with final_proposal, final_fairness and
    final_per_driver_effort (taken from the chosen proposal).
    """
    chosen_proposal = state.route_proposal_1
    chosen_fairness = state.fairness_check_1

    if state.route_proposal_2 and state.fairness_check_2:
        before = state.fairness_check_1["metrics"]
        after = state.fairness_check_2["metrics"]

        # max_gap is consulted only when the gini comparison does not
        # already favour proposal 2.
        prefer_second = after["gini_index"] <= before["gini_index"]
        if not prefer_second:
            prefer_second = after["max_gap"] < before["max_gap"]

        if prefer_second:
            chosen_proposal = state.route_proposal_2
            chosen_fairness = state.fairness_check_2

    return {
        "final_proposal": chosen_proposal,
        "final_fairness": chosen_fairness,
        "final_per_driver_effort": chosen_proposal["per_driver_effort"],
    }
# =============================================================================
# Node 6: Driver Liaison Agent
# =============================================================================
def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #6: Driver Liaison Agent.

    Ranks the allocations of the final proposal by effort (highest effort is
    rank 1), builds one DriverAssignmentProposal per allocation and one
    DriverContext per entry of ``state.driver_contexts``, and passes them to
    ``DriverLiaisonAgent.run_for_all_drivers``.

    Returns a partial state update with:
        liaison_feedback - decisions plus num_accept / num_counter /
                           num_force_accept
        decision_logs    - existing logs plus one NEGOTIATION_DECISIONS entry
    """
    run_id = state.allocation_run_id

    _publish_event_sync(
        run_id,
        "DRIVER_LIAISON",
        "NEGOTIATION",
        "STARTED",
        {"num_drivers": len(state.driver_models)},
    )

    from app.schemas.agent_schemas import AllocationItem

    liaison = DriverLiaisonAgent()

    # Fall back to proposal 1 / check 1 when no final selection is stored.
    proposal = state.final_proposal or state.route_proposal_1
    fairness = state.final_fairness or state.fairness_check_1

    # Descending by effort, so position 1 is the heaviest assignment.
    by_effort_desc = list(proposal["allocation"])
    by_effort_desc.sort(key=lambda entry: entry["effort"], reverse=True)

    driver_proposals: List[DriverAssignmentProposal] = [
        DriverAssignmentProposal(
            driver_id=str(entry["driver_id"]),
            route_id=str(entry["route_id"]),
            effort=entry["effort"],
            rank_in_team=position,
        )
        for position, entry in enumerate(by_effort_desc, start=1)
    ]

    team_metrics = fairness["metrics"]
    team_avg = team_metrics["avg_effort"]
    team_std = team_metrics["std_dev"]

    contexts: Dict[str, DriverContext] = {
        driver_id: DriverContext(**raw_context)
        for driver_id, raw_context in state.driver_contexts.items()
    }

    negotiation = liaison.run_for_all_drivers(
        proposals=driver_proposals,
        driver_contexts=contexts,
        effort_matrix=state.effort_matrix["matrix"],
        driver_ids=state.effort_matrix["driver_ids"],
        route_ids=state.effort_matrix["route_ids"],
        global_avg_effort=team_avg,
        global_std_effort=team_std,
    )

    decisions = []
    for decision in negotiation.decisions:
        decisions.append(
            decision.model_dump() if hasattr(decision, "model_dump") else decision
        )
    feedback = {
        "decisions": decisions,
        "num_accept": negotiation.num_accept,
        "num_counter": negotiation.num_counter,
        "num_force_accept": negotiation.num_force_accept,
    }

    snapshot_in = liaison.get_input_snapshot(driver_proposals, team_avg, team_std)
    entry = _create_decision_log(
        agent_name="DRIVER_LIAISON",
        step_type="NEGOTIATION_DECISIONS",
        input_snapshot=snapshot_in,
        output_snapshot=liaison.get_output_snapshot(negotiation),
    )

    _publish_event_sync(
        run_id,
        "DRIVER_LIAISON",
        "NEGOTIATION",
        "COMPLETED",
        {
            "num_accept": negotiation.num_accept,
            "num_counter": negotiation.num_counter,
            "num_force_accept": negotiation.num_force_accept,
        },
    )

    return {
        "liaison_feedback": feedback,
        "decision_logs": state.decision_logs + [entry],
    }
# =============================================================================
# Node 7: Final Resolution Agent
# =============================================================================
def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #7: Final Resolution Agent.

    When the liaison feedback contains COUNTER decisions, rebuilds the approved
    proposal, the decisions and the current fairness metrics as schema objects
    and passes them to ``FinalResolutionAgent.resolve_counters``.

    Returns a partial state update:
        - without COUNTER decisions: only ``resolution_result`` with an empty
          ``swaps_applied`` list (no decision log is added);
        - otherwise: ``resolution_result`` (swaps_applied, allocation,
          per_driver_effort, metrics), ``final_per_driver_effort`` and
          ``decision_logs`` with one FINAL_RESOLUTION entry appended.
    """
    run_id = state.allocation_run_id
    from app.schemas.agent_schemas import RoutePlanResult, AllocationItem, FairnessMetrics, DriverLiaisonDecision

    counter_decisions = [
        d for d in state.liaison_feedback["decisions"]
        if d["decision"] == "COUNTER"
    ]

    if not counter_decisions:
        # Nothing to resolve. The step is reported as COMPLETED with a reason;
        # there is no separate "skipped" state.
        _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {
            "reason": "no_counters",
            "swaps_applied": 0,
        })
        return {
            "resolution_result": {"swaps_applied": []},
        }

    _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {
        "num_counters": len(counter_decisions),
    })

    resolution_agent = FinalResolutionAgent()

    # Fall back to proposal 1 / check 1 when no final selection is stored.
    final_proposal = state.final_proposal or state.route_proposal_1
    final_fairness = state.final_fairness or state.fairness_check_1

    approved_proposal = RoutePlanResult(
        allocation=[AllocationItem(**a) for a in final_proposal["allocation"]],
        total_effort=final_proposal["total_effort"],
        avg_effort=final_proposal.get(
            "avg_effort",
            final_proposal["total_effort"] / len(final_proposal["allocation"])
            if final_proposal["allocation"] else 0.0,
        ),
        solver_status=final_proposal.get("solver_status", "OPTIMAL"),
        proposal_number=final_proposal["proposal_number"],
        per_driver_effort=final_proposal["per_driver_effort"],
    )
    decisions = [DriverLiaisonDecision(**d) for d in state.liaison_feedback["decisions"]]
    current_metrics = FairnessMetrics(**final_fairness["metrics"])

    resolution_result = resolution_agent.resolve_counters(
        approved_proposal=approved_proposal,
        decisions=decisions,
        effort_matrix=state.effort_matrix["matrix"],
        driver_ids=state.effort_matrix["driver_ids"],
        route_ids=state.effort_matrix["route_ids"],
        current_metrics=current_metrics,
    )

    resolution_dict = {
        "swaps_applied": [
            s.model_dump() if hasattr(s, 'model_dump') else s
            for s in resolution_result.swaps_applied
        ],
        "allocation": resolution_result.allocation,
        "per_driver_effort": resolution_result.per_driver_effort,
        "metrics": resolution_result.metrics,
    }

    # NOTE(review): the log entry is attributed to "ROUTE_PLANNER" while the
    # events above use "FINAL_RESOLUTION" - kept as-is, confirm it is intended.
    log_entry = _create_decision_log(
        agent_name="ROUTE_PLANNER",
        step_type="FINAL_RESOLUTION",
        input_snapshot=resolution_agent.get_input_snapshot(
            len(counter_decisions),
            current_metrics,
            final_fairness["metrics"]["avg_effort"],
        ),
        output_snapshot=resolution_agent.get_output_snapshot(resolution_result),
    )

    _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {
        "swaps_applied": len(resolution_result.swaps_applied),
    })

    if resolution_result.swaps_applied:
        # Swaps change who drives what, so take the agent's recomputed efforts.
        updated_effort = resolution_result.per_driver_effort
    else:
        # No swap happened: carry the current efforts forward as a copy. When
        # final_per_driver_effort was never stored, use the proposal's values,
        # mirroring the fallbacks used for final_proposal / final_fairness.
        baseline = state.final_per_driver_effort
        if baseline is None:
            baseline = final_proposal["per_driver_effort"]
        updated_effort = dict(baseline)

    return {
        "resolution_result": resolution_dict,
        "final_per_driver_effort": updated_effort,
        "decision_logs": state.decision_logs + [log_entry],
    }
# =============================================================================
# Node 8: Explainability Agent
# =============================================================================
def explainability_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #8: Explainability Agent.

    For every allocation in the final proposal, assembles a
    DriverExplanationInput (effort, rank, route summary, effort breakdown,
    global fairness metrics, driver history, liaison decision, swap flag) and
    passes it to ``ExplainabilityAgent.build_explanation_for_driver``.

    Driver ids are compared as strings throughout, so ids stored in another
    type in the efforts, liaison decisions or swaps still match the
    allocation's driver.

    Returns a partial state update with:
        explanations  - per driver id: driver_explanation, admin_explanation
                        and category
        decision_logs - existing logs plus one EXPLANATIONS_GENERATED entry
    """
    run_id = state.allocation_run_id

    _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {
        "num_drivers": len(state.driver_models),
    })

    explain_agent = ExplainabilityAgent()

    # Fall back to proposal 1 / check 1 when no final selection is stored.
    final_proposal = state.final_proposal or state.route_proposal_1
    final_fairness = state.final_fairness or state.fairness_check_1
    raw_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"]
    # Key by str so lookups with str(driver_id) below cannot miss.
    effort_by_driver = {str(did): value for did, value in raw_effort.items()}

    metrics = final_fairness["metrics"]
    avg_effort = metrics["avg_effort"]

    route_by_id = {str(r["id"]): r for r in state.route_models}
    driver_by_id = {str(d["id"]): d for d in state.driver_models}

    # Rank 1 is the driver with the highest effort.
    ranked = sorted(effort_by_driver.items(), key=lambda item: item[1], reverse=True)
    rank_by_driver = {did: position for position, (did, _) in enumerate(ranked, start=1)}
    num_drivers = len(effort_by_driver)

    liaison_by_driver: Dict[str, Any] = {}
    if state.liaison_feedback:
        for decision in state.liaison_feedback["decisions"]:
            liaison_by_driver[str(decision["driver_id"])] = decision

    swapped_drivers = set()
    if state.resolution_result and state.resolution_result.get("swaps_applied"):
        for swap in state.resolution_result["swaps_applied"]:
            swapped_drivers.add(str(swap["driver_a"]))
            swapped_drivers.add(str(swap["driver_b"]))

    explanations: Dict[str, Dict[str, Any]] = {}
    category_counts: Dict[str, int] = {}

    for alloc_item in final_proposal["allocation"]:
        driver_id_str = str(alloc_item["driver_id"])
        route_id_str = str(alloc_item["route_id"])

        driver = driver_by_id.get(driver_id_str, {})
        route = route_by_id.get(route_id_str, {})

        # Prefer the post-resolution effort over the proposal's own value.
        effort = effort_by_driver.get(driver_id_str, alloc_item["effort"])
        # Computed but not passed on to the explanation input.
        fairness_score = calculate_fairness_score(effort, avg_effort)

        driver_context = state.driver_contexts.get(driver_id_str, {})
        if driver_context:
            # Only a single recent average is available, not a daily series.
            history_efforts = [driver_context.get("recent_avg_effort", avg_effort)]
            history_hard_days = driver_context.get("recent_hard_days", 0)
        else:
            history_efforts = []
            history_hard_days = 0

        # presumably the breakdown is keyed "<driver_id>:<route_id>" by the
        # effort agent - TODO confirm against MLEffortAgent
        breakdown_key = f"{driver_id_str}:{route_id_str}"
        effort_breakdown_data = state.effort_matrix.get("breakdown", {}).get(breakdown_key, {})
        effort_breakdown = {
            "physical_effort": effort_breakdown_data.get("physical_effort", 0),
            "route_complexity": effort_breakdown_data.get("route_complexity", 0),
            "time_pressure": effort_breakdown_data.get("time_pressure", 0),
        }

        liaison_decision = liaison_by_driver.get(driver_id_str)

        # Recovery day: at least 3 recent hard days and today's effort below
        # 85% of the average.
        is_recovery = history_hard_days >= 3 and effort < avg_effort * 0.85

        explain_input = DriverExplanationInput(
            driver_id=driver_id_str,
            driver_name=driver.get("name", "Driver"),
            num_drivers=num_drivers,
            today_effort=effort,
            today_rank=rank_by_driver.get(driver_id_str, num_drivers),
            route_id=route_id_str,
            route_summary={
                "num_packages": route.get("num_packages", 0),
                "total_weight_kg": route.get("total_weight_kg", 0),
                "num_stops": route.get("num_stops", 0),
                "difficulty_score": route.get("route_difficulty_score", 0),
                "estimated_time_minutes": route.get("estimated_time_minutes", 0),
            },
            effort_breakdown=effort_breakdown,
            global_avg_effort=avg_effort,
            global_std_effort=metrics["std_dev"],
            global_gini_index=metrics["gini_index"],
            global_max_gap=metrics["max_gap"],
            history_efforts_last_7_days=history_efforts,
            history_hard_days_last_7=history_hard_days,
            is_recovery_day=is_recovery,
            had_manual_override=False,  # TODO: Query DB if needed
            liaison_decision=liaison_decision["decision"] if liaison_decision else None,
            swap_applied=driver_id_str in swapped_drivers,
        )

        explain_output = explain_agent.build_explanation_for_driver(explain_input)

        category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1

        explanations[driver_id_str] = {
            "driver_explanation": explain_output.driver_explanation,
            "admin_explanation": explain_output.admin_explanation,
            "category": explain_output.category,
        }

    log_entry = _create_decision_log(
        agent_name="EXPLAINABILITY",
        step_type="EXPLANATIONS_GENERATED",
        input_snapshot=explain_agent.get_input_snapshot(
            num_drivers=num_drivers,
            avg_effort=avg_effort,
            std_effort=metrics["std_dev"],
            gini_index=metrics["gini_index"],
            category_counts=category_counts,
        ),
        output_snapshot=explain_agent.get_output_snapshot(
            total_explanations=len(explanations),
            category_counts=category_counts,
        ),
    )

    _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED", {
        "total_explanations": len(explanations),
        "categories": category_counts,
    })

    return {
        "explanations": explanations,
        "decision_logs": state.decision_logs + [log_entry],
    }
# =============================================================================
# Conditional Edge Functions
# =============================================================================
def should_reoptimize(state: AllocationState) -> str:
    """
    Conditional edge: decide whether the planner must run a second time.

    Returns:
        "reoptimize" - fairness check 1 has status REOPTIMIZE and there is no
                       proposal 2 yet
        "continue"   - in every other case
    """
    first_check = state.fairness_check_1
    if not first_check:
        return "continue"
    if first_check.get("status") != "REOPTIMIZE":
        return "continue"
    # Re-optimize at most once: an existing proposal 2 ends the loop.
    return "continue" if state.route_proposal_2 else "reoptimize"
def has_counter_decisions(state: AllocationState) -> str:
    """
    Conditional edge: route to swap resolution only when it is needed.

    Returns:
        "resolve" - the liaison feedback holds at least one COUNTER decision
        "skip"    - there is no feedback, or none of its decisions is COUNTER
    """
    feedback = state.liaison_feedback
    if not feedback:
        return "skip"

    counters = [
        entry for entry in feedback["decisions"]
        if entry["decision"] == "COUNTER"
    ]
    return "resolve" if counters else "skip"