# FairRelay / brain / app / services / langgraph_nodes.py
# Last commit by MouleeswaranM (51583f1):
#   fix: ModelWrapper.is_ev enum comparison + improve event publish reliability for real-time SSE (#20)
"""
LangGraph node wrappers for Fair Dispatch agents.
Each node wraps an existing agent with minimal changes, preserving the original logic.
PRODUCTION FIXES APPLIED:
- ModelWrapper.is_ev: handles "EV", "ELECTRIC", VehicleType.EV enum values
- _publish_event_sync: improved reliability with asyncio.ensure_future
"""
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple
from uuid import UUID
import asyncio
import logging
from app.schemas.allocation_state import AllocationState
from app.schemas.agent_schemas import (
FairnessThresholds,
DriverAssignmentProposal,
DriverContext,
)
from app.services.ml_effort_agent import MLEffortAgent
from app.services.route_planner_agent import RoutePlannerAgent
from app.services.fairness_manager_agent import FairnessManagerAgent
from app.services.driver_liaison_agent import DriverLiaisonAgent
from app.services.final_resolution import FinalResolutionAgent
from app.services.explainability import ExplainabilityAgent
from app.schemas.explainability import DriverExplanationInput
from app.services.fairness import calculate_fairness_score
from app.core.events import agent_event_bus, make_agent_event
# Module-level logger for the LangGraph node layer; _publish_event_sync reports
# failed best-effort event publishes through it.
logger = logging.getLogger("fairrelay.langgraph")
class ModelWrapper:
    """Expose a plain dict through attribute access so agents can treat it like a model.

    Reading a missing (non-dunder) attribute returns ``None``, mirroring
    ``dict.get``, because the wrapped agents read optional fields as attributes.

    Dunder lookups, and any lookup made before ``_data`` has been set, raise
    ``AttributeError``. Without this, protocol probes made by ``copy``,
    ``pickle`` and ``hasattr`` (e.g. ``__setstate__`` on a freshly
    ``__new__``-ed instance) would recurse forever or receive ``None``.
    """

    # Upper-cased spellings that identify an electric vehicle. "VEHICLETYPE.EV"
    # is kept for str() of a VehicleType.EV enum member.
    _EV_ALIASES = frozenset({"EV", "ELECTRIC", "VEHICLETYPE.EV"})

    def __init__(self, data: Dict[str, Any]):
        self._data = data

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        # Read via __dict__ so an uninitialised instance cannot re-enter
        # __getattr__("_data") and recurse.
        data = self.__dict__.get("_data")
        if data is None:
            raise AttributeError(name)
        return data.get(name)

    @property
    def is_ev(self) -> bool:
        """Return True when ``vehicle_type`` denotes an electric vehicle.

        Accepts plain strings (case-insensitive, surrounding whitespace
        ignored) and Enum members, which are matched on their ``str()``,
        ``.value`` and ``.name``.
        """
        vt = self._data.get("vehicle_type")
        if vt is None:
            return False
        spellings = {str(vt).strip().upper()}
        # Enum members: str() depends on the enum class name and Python
        # version, so also compare the member's value and name.
        for attr in ("value", "name"):
            part = getattr(vt, attr, None)
            if part is not None:
                spellings.add(str(part).strip().upper())
        return not spellings.isdisjoint(self._EV_ALIASES)
def _publish_event_sync(
allocation_run_id: Optional[str],
agent_name: str,
step_type: str,
state: str,
payload: Optional[Dict[str, Any]] = None,
) -> None:
"""
Publish an agent event synchronously (fire-and-forget).
Used by LangGraph nodes which are synchronous functions.
Uses asyncio.ensure_future for reliable delivery when a loop is running.
"""
if not allocation_run_id:
return
event = make_agent_event(
allocation_run_id=allocation_run_id,
agent_name=agent_name,
step_type=step_type,
state=state,
payload=payload,
)
# Schedule async publish on the running event loop
try:
loop = asyncio.get_running_loop()
asyncio.ensure_future(agent_event_bus.publish(event), loop=loop)
except RuntimeError:
# No running loop — this shouldn't happen in FastAPI context
# but handle gracefully for testing
try:
asyncio.run(agent_event_bus.publish(event))
except Exception as e:
logger.warning(f"Failed to publish agent event: {e}")
def _create_decision_log(
agent_name: str,
step_type: str,
input_snapshot: Dict[str, Any],
output_snapshot: Dict[str, Any],
) -> Dict[str, Any]:
"""Create a decision log entry compatible with DecisionLog model."""
return {
"timestamp": datetime.utcnow().isoformat(),
"agent_name": agent_name,
"step_type": step_type,
"input_snapshot": input_snapshot,
"output_snapshot": output_snapshot,
}
# =============================================================================
# Node 1: ML Effort Agent
# =============================================================================
def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
    """
    LangGraph node #1: ML Effort Agent.

    Runs ``MLEffortAgent.compute_effort_matrix`` over every driver/route pair.
    Returns the result as plain data under ``effort_matrix``, plus the prior
    decision logs with one new "MATRIX_GENERATION" entry appended. STARTED and
    COMPLETED agent events bracket the computation.

    EV tuning is read from ``state.config_used``: ``ev_safety_margin_pct``
    (default 10.0) and ``ev_charging_penalty_weight`` (default 0.3). An empty
    or missing config uses both defaults.
    """
    run_id = state.allocation_run_id
    driver_rows = state.driver_models
    route_rows = state.route_models

    _publish_event_sync(
        run_id, "ML_EFFORT", "MATRIX_GENERATION", "STARTED",
        {"num_drivers": len(driver_rows), "num_routes": len(route_rows)},
    )
    agent = MLEffortAgent()

    cfg = state.config_used
    if cfg:
        ev_config = {
            "safety_margin_pct": cfg.get("ev_safety_margin_pct", 10.0),
            "charging_penalty_weight": cfg.get("ev_charging_penalty_weight", 0.3),
        }
    else:
        ev_config = {"safety_margin_pct": 10.0, "charging_penalty_weight": 0.3}

    # Agents expect attribute access, so wrap the raw dicts.
    wrapped_drivers = list(map(ModelWrapper, driver_rows))
    wrapped_routes = list(map(ModelWrapper, route_rows))
    result = agent.compute_effort_matrix(
        drivers=wrapped_drivers, routes=wrapped_routes, ev_config=ev_config,
    )

    infeasible = result.infeasible_pairs
    serialized_breakdown = {}
    for pair_key, detail in result.breakdown.items():
        # Pydantic entries are dumped to dicts; anything else is stored as-is.
        serialized_breakdown[pair_key] = detail.model_dump() if hasattr(detail, "model_dump") else detail

    effort_dict = {
        "matrix": result.matrix,
        "driver_ids": result.driver_ids,
        "route_ids": result.route_ids,
        "breakdown": serialized_breakdown,
        "stats": result.stats,
        "infeasible_pairs": list(infeasible) if infeasible else [],
    }

    input_snapshot = agent.get_input_snapshot(wrapped_drivers, wrapped_routes)
    output_snapshot = dict(agent.get_output_snapshot(result))
    output_snapshot["num_infeasible_ev_pairs"] = len(infeasible) if infeasible else 0
    log_entry = _create_decision_log(
        agent_name="ML_EFFORT",
        step_type="MATRIX_GENERATION",
        input_snapshot=input_snapshot,
        output_snapshot=output_snapshot,
    )

    stats = result.stats
    _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "COMPLETED", {
        "min_effort": stats.get("min", 0),
        "max_effort": stats.get("max", 0),
        "avg_effort": stats.get("avg", 0),
    })
    return {"effort_matrix": effort_dict, "decision_logs": state.decision_logs + [log_entry]}
# =============================================================================
# Node 2: Route Planner Agent (Proposal 1)
# =============================================================================
def route_planner_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node #2: Route Planner Agent - Proposal 1 (OR-Tools optimization).

    Rebuilds an ``EffortMatrixResult`` from ``state.effort_matrix`` (with an
    empty breakdown, and zeroed stats when none were stored). It then asks
    ``RoutePlannerAgent.plan`` for proposal 1, using the recovery targets and
    ``recovery_penalty_weight`` (default 3.0) from the config. The proposal is
    returned as plain data under ``route_proposal_1``, with a "PROPOSAL_1"
    decision log appended.
    """
    run_id = state.allocation_run_id
    _publish_event_sync(
        run_id, "ROUTE_PLANNER", "PROPOSAL_1", "STARTED",
        {"num_drivers": len(state.driver_models), "num_routes": len(state.route_models)},
    )
    planner = RoutePlannerAgent()
    from app.schemas.agent_schemas import EffortMatrixResult

    em = state.effort_matrix
    matrix_values = em["matrix"]
    matrix_stats = em.get("stats") or {"min": 0, "max": 0, "avg": 0}
    effort_result = EffortMatrixResult(
        matrix=matrix_values,
        driver_ids=em["driver_ids"],
        route_ids=em["route_ids"],
        breakdown={},
        stats=matrix_stats,
        infeasible_pairs=list(em.get("infeasible_pairs", [])),
    )

    cfg = state.config_used
    penalty_weight = cfg.get("recovery_penalty_weight", 3.0) if cfg else 3.0
    driver_views = list(map(ModelWrapper, state.driver_models))
    route_views = list(map(ModelWrapper, state.route_models))
    plan = planner.plan(
        effort_result=effort_result,
        drivers=driver_views,
        routes=route_views,
        recovery_targets=state.recovery_targets or {},
        recovery_penalty_weight=penalty_weight,
        proposal_number=1,
    )

    serialized_allocation = []
    for item in plan.allocation:
        serialized_allocation.append(item.model_dump() if hasattr(item, "model_dump") else item)
    proposal_dict = dict(
        allocation=serialized_allocation,
        total_effort=plan.total_effort,
        avg_effort=plan.avg_effort,
        solver_status=plan.solver_status,
        proposal_number=plan.proposal_number,
        per_driver_effort=plan.per_driver_effort,
    )

    in_snap = planner.get_input_snapshot(effort_result)
    out_snap = planner.get_output_snapshot(plan)
    log_entry = _create_decision_log(
        agent_name="ROUTE_PLANNER", step_type="PROPOSAL_1",
        input_snapshot=in_snap, output_snapshot=out_snap,
    )
    _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "COMPLETED", {
        "total_effort": plan.total_effort,
        "num_assignments": len(plan.allocation),
        "solver_status": plan.solver_status,
    })
    return {"route_proposal_1": proposal_dict, "decision_logs": state.decision_logs + [log_entry]}
# =============================================================================
# Node 3: Fairness Check Agent
# =============================================================================
def _fairness_thresholds_from_config(config_used: Optional[Dict[str, Any]]) -> "FairnessThresholds":
    """Build FairnessThresholds from the run config (defaults: Gini 0.33, stddev 25.0, max gap 25.0)."""
    config = config_used or {}
    return FairnessThresholds(
        gini_threshold=config.get("gini_threshold", 0.33),
        stddev_threshold=config.get("stddev_threshold", 25.0),
        max_gap_threshold=config.get("max_gap_threshold", 25.0),
    )


def _fairness_metrics_to_dict(metrics: Any) -> Dict[str, Any]:
    """Serialize fairness metrics, preferring pydantic ``model_dump`` when available."""
    if hasattr(metrics, "model_dump"):
        return metrics.model_dump()
    return {
        "avg_effort": metrics.avg_effort,
        "std_dev": metrics.std_dev,
        "gini_index": metrics.gini_index,
        "max_effort": metrics.max_effort,
        "min_effort": metrics.min_effort,
        "max_gap": metrics.max_gap,
    }


def _run_fairness_check(state: "AllocationState", proposal_number: int) -> Dict[str, Any]:
    """Shared body of the fairness-check nodes.

    Evaluates ``state.route_proposal_<proposal_number>`` with
    ``FairnessManagerAgent``. Returns ``{"fairness_check_<n>": ..., "decision_logs": ...}``,
    publishing FAIRNESS_CHECK_<n> STARTED/COMPLETED events around the check.
    """
    run_id = state.allocation_run_id
    event_step = f"FAIRNESS_CHECK_{proposal_number}"
    _publish_event_sync(
        run_id, "FAIRNESS_MANAGER", event_step, "STARTED", {"proposal_number": proposal_number},
    )
    fairness_agent = FairnessManagerAgent(thresholds=_fairness_thresholds_from_config(state.config_used))
    from app.schemas.agent_schemas import RoutePlanResult, AllocationItem

    proposal = getattr(state, f"route_proposal_{proposal_number}")
    allocation = proposal["allocation"]
    total_effort = proposal["total_effort"]
    plan_result = RoutePlanResult(
        allocation=[AllocationItem(**item) for item in allocation],
        total_effort=total_effort,
        # Fall back to the mean over assignments when avg_effort was not stored.
        avg_effort=proposal.get("avg_effort", total_effort / max(len(allocation), 1)),
        solver_status=proposal.get("solver_status", "OPTIMAL"),
        proposal_number=proposal_number,
        per_driver_effort=proposal["per_driver_effort"],
    )
    fairness_result = fairness_agent.check(plan_result, proposal_number=proposal_number)

    recommendations = fairness_result.recommendations
    fairness_dict = {
        "status": fairness_result.status,
        "proposal_number": proposal_number,
        "metrics": _fairness_metrics_to_dict(fairness_result.metrics),
        # Only pydantic recommendations are kept; anything else is recorded as None.
        "recommendations": (
            recommendations.model_dump()
            if recommendations and hasattr(recommendations, "model_dump")
            else None
        ),
    }
    log_entry = _create_decision_log(
        agent_name="FAIRNESS_MANAGER",
        step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
        input_snapshot=fairness_agent.get_input_snapshot(plan_result),
        output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
    )
    _publish_event_sync(run_id, "FAIRNESS_MANAGER", event_step, "COMPLETED", {
        "status": fairness_result.status,
        "gini_index": fairness_dict["metrics"]["gini_index"],
    })
    return {
        f"fairness_check_{proposal_number}": fairness_dict,
        "decision_logs": state.decision_logs + [log_entry],
    }


def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node #3: Fairness Manager Agent — evaluates Gini/stddev/max_gap of proposal 1."""
    return _run_fairness_check(state, proposal_number=1)


def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node for the second fairness check (proposal 2, after re-optimization)."""
    return _run_fairness_check(state, proposal_number=2)
# =============================================================================
# Node 4: Route Planner Re-optimization (Proposal 2)
# =============================================================================
def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node #4: Route Planner - Proposal 2 with fairness penalties.

    Converts the recommendations from fairness check 1 (if any) into
    per-driver penalties based on proposal 1's per-driver effort. It then
    re-plans with those penalties plus the recovery settings, and returns the
    result under ``route_proposal_2`` with a "PROPOSAL_2" decision log.
    """
    run_id = state.allocation_run_id
    _publish_event_sync(
        run_id, "ROUTE_PLANNER", "PROPOSAL_2", "STARTED", {"reason": "fairness_reoptimization"},
    )
    planner = RoutePlannerAgent()
    from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations

    em = state.effort_matrix
    matrix_values = em["matrix"]
    matrix_stats = em.get("stats") or {"min": 0, "max": 0, "avg": 0}
    effort_result = EffortMatrixResult(
        matrix=matrix_values,
        driver_ids=em["driver_ids"],
        route_ids=em["route_ids"],
        breakdown={},
        stats=matrix_stats,
        infeasible_pairs=list(em.get("infeasible_pairs", [])),
    )

    rec_payload = state.fairness_check_1.get("recommendations")
    if rec_payload:
        penalties = planner.build_penalties_from_recommendations(
            FairnessRecommendations(**rec_payload),
            state.route_proposal_1["per_driver_effort"],
        )
    else:
        penalties = {}

    cfg = state.config_used
    penalty_weight = cfg.get("recovery_penalty_weight", 3.0) if cfg else 3.0
    driver_views = list(map(ModelWrapper, state.driver_models))
    route_views = list(map(ModelWrapper, state.route_models))
    plan = planner.plan(
        effort_result=effort_result,
        drivers=driver_views,
        routes=route_views,
        fairness_penalties=penalties,
        recovery_targets=state.recovery_targets or {},
        recovery_penalty_weight=penalty_weight,
        proposal_number=2,
    )

    serialized_allocation = []
    for item in plan.allocation:
        serialized_allocation.append(item.model_dump() if hasattr(item, "model_dump") else item)
    proposal_dict = dict(
        allocation=serialized_allocation,
        total_effort=plan.total_effort,
        avg_effort=plan.avg_effort,
        solver_status=plan.solver_status,
        proposal_number=2,
        per_driver_effort=plan.per_driver_effort,
    )

    in_snap = planner.get_input_snapshot(effort_result, penalties)
    out_snap = planner.get_output_snapshot(plan)
    log_entry = _create_decision_log(
        agent_name="ROUTE_PLANNER", step_type="PROPOSAL_2",
        input_snapshot=in_snap, output_snapshot=out_snap,
    )
    _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "COMPLETED", {
        "total_effort": plan.total_effort,
        "solver_status": plan.solver_status,
    })
    return {"route_proposal_2": proposal_dict, "decision_logs": state.decision_logs + [log_entry]}
# =============================================================================
# Node 5: Select Final Proposal
# =============================================================================
def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
    """Choose between proposal 1 and proposal 2 by comparing their fairness metrics.

    Proposal 2 wins only if it and its fairness check both exist, AND either
    its Gini index is not worse (ties favour proposal 2) or its max_gap is
    strictly smaller. Otherwise proposal 1 is kept. Returns the chosen
    proposal, its fairness check and its per-driver effort map.
    """
    prefer_second = False
    if state.route_proposal_2 and state.fairness_check_2:
        first = state.fairness_check_1["metrics"]
        second = state.fairness_check_2["metrics"]
        gini_not_worse = second["gini_index"] <= first["gini_index"]
        prefer_second = gini_not_worse or second["max_gap"] < first["max_gap"]

    if prefer_second:
        chosen, chosen_fairness = state.route_proposal_2, state.fairness_check_2
    else:
        chosen, chosen_fairness = state.route_proposal_1, state.fairness_check_1
    return {
        "final_proposal": chosen,
        "final_fairness": chosen_fairness,
        "final_per_driver_effort": chosen["per_driver_effort"],
    }
# =============================================================================
# Node 6: Driver Liaison Agent
# =============================================================================
def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node #6: Driver Liaison - per-driver comfort band negotiation.

    Ranks the final proposal's assignments by effort (rank 1 = highest effort;
    ties keep allocation order because ``sorted`` is stable). Each ranked
    assignment becomes a ``DriverAssignmentProposal``. The ranked proposals
    and the stored driver contexts are then passed to
    ``DriverLiaisonAgent.run_for_all_drivers`` together with the global
    average/std-dev effort. Decisions and accept/counter/force-accept counts
    are returned under ``liaison_feedback``.
    """
    run_id = state.allocation_run_id
    _publish_event_sync(
        run_id, "DRIVER_LIAISON", "NEGOTIATION", "STARTED", {"num_drivers": len(state.driver_models)},
    )
    liaison_agent = DriverLiaisonAgent()
    final_proposal = state.final_proposal or state.route_proposal_1
    final_fairness = state.final_fairness or state.fairness_check_1

    ranked = sorted(final_proposal["allocation"], key=lambda item: item["effort"], reverse=True)
    driver_proposals: List[DriverAssignmentProposal] = [
        DriverAssignmentProposal(
            driver_id=str(item["driver_id"]),
            route_id=str(item["route_id"]),
            effort=item["effort"],
            rank_in_team=rank,
        )
        for rank, item in enumerate(ranked, start=1)
    ]

    metrics = final_fairness["metrics"]
    avg_effort = metrics["avg_effort"]
    std_effort = metrics["std_dev"]
    driver_context_objs: Dict[str, DriverContext] = {
        driver_id: DriverContext(**context_dict)
        for driver_id, context_dict in (state.driver_contexts or {}).items()
    }

    em = state.effort_matrix
    negotiation_result = liaison_agent.run_for_all_drivers(
        proposals=driver_proposals,
        driver_contexts=driver_context_objs,
        effort_matrix=em["matrix"],
        driver_ids=em["driver_ids"],
        route_ids=em["route_ids"],
        global_avg_effort=avg_effort,
        global_std_effort=std_effort,
    )
    liaison_dict = {
        "decisions": [
            d.model_dump() if hasattr(d, "model_dump") else d for d in negotiation_result.decisions
        ],
        "num_accept": negotiation_result.num_accept,
        "num_counter": negotiation_result.num_counter,
        "num_force_accept": negotiation_result.num_force_accept,
    }
    log_entry = _create_decision_log(
        agent_name="DRIVER_LIAISON",
        step_type="NEGOTIATION_DECISIONS",
        input_snapshot=liaison_agent.get_input_snapshot(driver_proposals, avg_effort, std_effort),
        output_snapshot=liaison_agent.get_output_snapshot(negotiation_result),
    )
    _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "COMPLETED", {
        "num_accept": negotiation_result.num_accept,
        "num_counter": negotiation_result.num_counter,
    })
    return {"liaison_feedback": liaison_dict, "decision_logs": state.decision_logs + [log_entry]}
# =============================================================================
# Node 7: Final Resolution Agent
# =============================================================================
def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node #7: Final Resolution - resolves COUNTER decisions via swaps.

    With no COUNTER decisions, publishes a COMPLETED event (reason
    "no_counters") and returns an empty swap list without a decision log.

    Otherwise it rebuilds the approved proposal and the liaison decisions as
    schema objects and calls ``FinalResolutionAgent.resolve_counters``. It
    returns the resolution summary plus ``final_per_driver_effort``: the
    resolver's per-driver effort when any swap was applied, else a copy of the
    existing map (or ``{}`` when that is empty/missing).
    """
    run_id = state.allocation_run_id
    from app.schemas.agent_schemas import (
        AllocationItem,
        DriverLiaisonDecision,
        FairnessMetrics,
        RoutePlanResult,
    )

    all_decisions = state.liaison_feedback["decisions"]
    counter_decisions = list(filter(lambda d: d["decision"] == "COUNTER", all_decisions))
    if not counter_decisions:
        _publish_event_sync(
            run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED",
            {"reason": "no_counters", "swaps_applied": 0},
        )
        return {"resolution_result": {"swaps_applied": []}}

    _publish_event_sync(
        run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {"num_counters": len(counter_decisions)},
    )
    resolver = FinalResolutionAgent()
    proposal = state.final_proposal or state.route_proposal_1
    fairness = state.final_fairness or state.fairness_check_1

    allocation_rows = proposal["allocation"]
    items = [AllocationItem(**row) for row in allocation_rows]
    total = proposal["total_effort"]
    approved_proposal = RoutePlanResult(
        allocation=items,
        total_effort=total,
        avg_effort=proposal.get("avg_effort", total / max(len(allocation_rows), 1)),
        solver_status=proposal.get("solver_status", "OPTIMAL"),
        proposal_number=proposal["proposal_number"],
        per_driver_effort=proposal["per_driver_effort"],
    )
    decision_objs = [DriverLiaisonDecision(**row) for row in all_decisions]
    current_metrics = FairnessMetrics(**fairness["metrics"])

    em = state.effort_matrix
    outcome = resolver.resolve_counters(
        approved_proposal=approved_proposal,
        decisions=decision_objs,
        effort_matrix=em["matrix"],
        driver_ids=em["driver_ids"],
        route_ids=em["route_ids"],
        current_metrics=current_metrics,
    )

    swaps = outcome.swaps_applied
    resolution_dict = {
        "swaps_applied": [s.model_dump() if hasattr(s, "model_dump") else s for s in swaps],
        "allocation": outcome.allocation,
        "per_driver_effort": outcome.per_driver_effort,
        "metrics": outcome.metrics,
    }
    in_snap = resolver.get_input_snapshot(
        len(counter_decisions), current_metrics, fairness["metrics"]["avg_effort"],
    )
    out_snap = resolver.get_output_snapshot(outcome)
    log_entry = _create_decision_log(
        agent_name="FINAL_RESOLUTION", step_type="SWAP_RESOLUTION",
        input_snapshot=in_snap, output_snapshot=out_snap,
    )
    _publish_event_sync(
        run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"swaps_applied": len(swaps)},
    )

    if swaps:
        updated_effort = outcome.per_driver_effort
    elif state.final_per_driver_effort:
        updated_effort = state.final_per_driver_effort.copy()
    else:
        updated_effort = {}
    return {
        "resolution_result": resolution_dict,
        "final_per_driver_effort": updated_effort,
        "decision_logs": state.decision_logs + [log_entry],
    }
# =============================================================================
# Node 8: Explainability Agent
# =============================================================================
# A driver is flagged as on a recovery day when they had at least this many
# recent hard days AND today's effort is below this fraction of the team average.
_RECOVERY_MIN_HARD_DAYS = 3
_RECOVERY_EFFORT_RATIO = 0.85


def _explain_route_summary(route: Dict[str, Any]) -> Dict[str, Any]:
    """Project a route dict onto the summary fields passed to the explainability agent (missing -> 0)."""
    return {
        "num_packages": route.get("num_packages", 0),
        "total_weight_kg": route.get("total_weight_kg", 0),
        "num_stops": route.get("num_stops", 0),
        "difficulty_score": route.get("route_difficulty_score", 0),
        "estimated_time_minutes": route.get("estimated_time_minutes", 0),
    }


def _explain_effort_breakdown(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the three effort components from an effort-matrix breakdown entry (missing -> 0)."""
    return {
        "physical_effort": entry.get("physical_effort", 0),
        "route_complexity": entry.get("route_complexity", 0),
        "time_pressure": entry.get("time_pressure", 0),
    }


def explainability_node(state: AllocationState) -> Dict[str, Any]:
    """LangGraph node #8: Explainability Agent — generates per-driver explanations.

    For every assignment in the final proposal, builds a ``DriverExplanationInput``
    and asks ``ExplainabilityAgent`` for driver/admin explanations. The input
    covers effort, team rank, route summary, effort breakdown, global fairness
    metrics, recent history, liaison decision and swap flag. Returns
    ``{driver_id: {...}}`` under ``explanations`` and appends one decision log.
    All driver-id lookups use ``str`` keys.
    """
    run_id = state.allocation_run_id
    _publish_event_sync(
        run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {"num_drivers": len(state.driver_models)},
    )
    explain_agent = ExplainabilityAgent()
    final_proposal = state.final_proposal or state.route_proposal_1
    final_fairness = state.final_fairness or state.fairness_check_1
    final_per_driver_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"]
    metrics = final_fairness["metrics"]
    avg_effort = metrics["avg_effort"]

    route_by_id = {str(r["id"]): r for r in state.route_models}
    driver_by_id = {str(d["id"]): d for d in state.driver_models}
    # Normalise effort keys to str so lookups by str(driver_id) also hit
    # non-str (e.g. UUID) keys.
    effort_by_driver = {str(did): eff for did, eff in final_per_driver_effort.items()}
    ranked = sorted(effort_by_driver.items(), key=lambda kv: kv[1], reverse=True)
    rank_by_driver = {did: rank for rank, (did, _) in enumerate(ranked, start=1)}
    num_drivers = len(final_per_driver_effort)

    liaison_by_driver: Dict[str, Dict[str, Any]] = {}
    if state.liaison_feedback:
        for decision in state.liaison_feedback["decisions"]:
            liaison_by_driver[str(decision["driver_id"])] = decision

    swapped_drivers = set()
    if state.resolution_result and state.resolution_result.get("swaps_applied"):
        for swap in state.resolution_result["swaps_applied"]:
            # NOTE(review): assumes swap records carry "driver_a"/"driver_b" keys — confirm
            # against FinalResolutionAgent's swap schema.
            swapped_drivers.add(str(swap.get("driver_a", "")))
            swapped_drivers.add(str(swap.get("driver_b", "")))

    breakdowns = state.effort_matrix.get("breakdown") or {}
    driver_contexts = state.driver_contexts or {}
    explanations: Dict[str, Dict[str, Any]] = {}
    category_counts: Dict[str, int] = {}

    # NOTE(review): final_resolution_node updates final_per_driver_effort but not
    # final_proposal["allocation"], so for swapped drivers route_id/route_summary
    # here may describe the pre-swap route — confirm intended behavior.
    for alloc_item in final_proposal["allocation"]:
        driver_id_str = str(alloc_item["driver_id"])
        route_id_str = str(alloc_item["route_id"])
        driver = driver_by_id.get(driver_id_str, {})
        route = route_by_id.get(route_id_str, {})
        effort = effort_by_driver.get(driver_id_str, alloc_item["effort"])
        fairness_score = calculate_fairness_score(effort, avg_effort)

        driver_context = driver_contexts.get(driver_id_str, {})
        history_efforts = [driver_context.get("recent_avg_effort", avg_effort)] if driver_context else []
        history_hard_days = driver_context.get("recent_hard_days", 0) if driver_context else 0
        is_recovery = (
            history_hard_days >= _RECOVERY_MIN_HARD_DAYS
            and effort < avg_effort * _RECOVERY_EFFORT_RATIO
        )

        # Breakdown entries are keyed "<driver_id>:<route_id>".
        breakdown_entry = breakdowns.get(f"{driver_id_str}:{route_id_str}", {})
        liaison_decision = liaison_by_driver.get(driver_id_str)

        explain_input = DriverExplanationInput(
            driver_id=driver_id_str,
            driver_name=driver.get("name", "Driver"),
            num_drivers=num_drivers,
            today_effort=effort,
            today_rank=rank_by_driver.get(driver_id_str, num_drivers),
            route_id=route_id_str,
            route_summary=_explain_route_summary(route),
            effort_breakdown=_explain_effort_breakdown(breakdown_entry),
            global_avg_effort=avg_effort,
            global_std_effort=metrics["std_dev"],
            global_gini_index=metrics["gini_index"],
            global_max_gap=metrics["max_gap"],
            history_efforts_last_7_days=history_efforts,
            history_hard_days_last_7=history_hard_days,
            is_recovery_day=is_recovery,
            had_manual_override=False,
            liaison_decision=liaison_decision["decision"] if liaison_decision else None,
            swap_applied=driver_id_str in swapped_drivers,
        )
        explain_output = explain_agent.build_explanation_for_driver(explain_input)
        category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1
        explanations[driver_id_str] = {
            "driver_explanation": explain_output.driver_explanation,
            "admin_explanation": explain_output.admin_explanation,
            "category": explain_output.category,
        }

    log_entry = _create_decision_log(
        agent_name="EXPLAINABILITY",
        step_type="EXPLANATIONS_GENERATED",
        input_snapshot=explain_agent.get_input_snapshot(
            num_drivers=num_drivers,
            avg_effort=avg_effort,
            std_effort=metrics["std_dev"],
            gini_index=metrics["gini_index"],
            category_counts=category_counts,
        ),
        output_snapshot=explain_agent.get_output_snapshot(
            total_explanations=len(explanations), category_counts=category_counts,
        ),
    )
    _publish_event_sync(
        run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED",
        {"total_explanations": len(explanations), "categories": category_counts},
    )
    return {"explanations": explanations, "decision_logs": state.decision_logs + [log_entry]}
# =============================================================================
# Conditional Edge Functions
# =============================================================================
def should_reoptimize(state: AllocationState) -> str:
    """Conditional edge: return "reoptimize" or "continue".

    Returns "reoptimize" only when fairness check 1 exists, reports status
    "REOPTIMIZE", and no proposal 2 has been produced yet. Otherwise returns
    "continue".
    """
    check = state.fairness_check_1
    wants_reoptimize = bool(check) and check.get("status") == "REOPTIMIZE"
    return "reoptimize" if wants_reoptimize and not state.route_proposal_2 else "continue"
def has_counter_decisions(state: AllocationState) -> str:
    """Conditional edge: "resolve" if any liaison decision is COUNTER, else "skip".

    Uses ``any`` so scanning stops at the first COUNTER decision.
    """
    feedback = state.liaison_feedback
    if feedback and any(d["decision"] == "COUNTER" for d in feedback["decisions"]):
        return "resolve"
    return "skip"