| """ |
| LangGraph node wrappers for Fair Dispatch agents. |
| Each node wraps an existing agent with minimal changes, preserving the original logic. |
| |
| PRODUCTION FIXES APPLIED: |
| - ModelWrapper.is_ev: handles "EV", "ELECTRIC", VehicleType.EV enum values |
| - _publish_event_sync: improved reliability with asyncio.ensure_future |
| """ |
|
|
| from datetime import datetime |
| from typing import Dict, Any, List, Optional, Tuple |
| from uuid import UUID |
| import asyncio |
| import logging |
|
|
| from app.schemas.allocation_state import AllocationState |
| from app.schemas.agent_schemas import ( |
| FairnessThresholds, |
| DriverAssignmentProposal, |
| DriverContext, |
| ) |
| from app.services.ml_effort_agent import MLEffortAgent |
| from app.services.route_planner_agent import RoutePlannerAgent |
| from app.services.fairness_manager_agent import FairnessManagerAgent |
| from app.services.driver_liaison_agent import DriverLiaisonAgent |
| from app.services.final_resolution import FinalResolutionAgent |
| from app.services.explainability import ExplainabilityAgent |
| from app.schemas.explainability import DriverExplanationInput |
| from app.services.fairness import calculate_fairness_score |
| from app.core.events import agent_event_bus, make_agent_event |
|
|
| logger = logging.getLogger("fairrelay.langgraph") |
|
|
|
|
| class ModelWrapper: |
| """Helper to wrap dicts as objects for agent compatibility.""" |
| def __init__(self, data: Dict[str, Any]): |
| self._data = data |
| |
| def __getattr__(self, name: str) -> Any: |
| return self._data.get(name) |
| |
| @property |
| def is_ev(self) -> bool: |
| """Check if driver has an EV - handles all possible enum/string formats.""" |
| vt = self._data.get("vehicle_type", "") |
| vt_str = str(vt).upper() |
| return vt_str in ("EV", "ELECTRIC", "VEHICLETYPE.EV") |
|
|
|
|
| def _publish_event_sync( |
| allocation_run_id: Optional[str], |
| agent_name: str, |
| step_type: str, |
| state: str, |
| payload: Optional[Dict[str, Any]] = None, |
| ) -> None: |
| """ |
| Publish an agent event synchronously (fire-and-forget). |
| Used by LangGraph nodes which are synchronous functions. |
| |
| Uses asyncio.ensure_future for reliable delivery when a loop is running. |
| """ |
| if not allocation_run_id: |
| return |
| |
| event = make_agent_event( |
| allocation_run_id=allocation_run_id, |
| agent_name=agent_name, |
| step_type=step_type, |
| state=state, |
| payload=payload, |
| ) |
| |
| |
| try: |
| loop = asyncio.get_running_loop() |
| asyncio.ensure_future(agent_event_bus.publish(event), loop=loop) |
| except RuntimeError: |
| |
| |
| try: |
| asyncio.run(agent_event_bus.publish(event)) |
| except Exception as e: |
| logger.warning(f"Failed to publish agent event: {e}") |
|
|
|
|
| def _create_decision_log( |
| agent_name: str, |
| step_type: str, |
| input_snapshot: Dict[str, Any], |
| output_snapshot: Dict[str, Any], |
| ) -> Dict[str, Any]: |
| """Create a decision log entry compatible with DecisionLog model.""" |
| return { |
| "timestamp": datetime.utcnow().isoformat(), |
| "agent_name": agent_name, |
| "step_type": step_type, |
| "input_snapshot": input_snapshot, |
| "output_snapshot": output_snapshot, |
| } |
|
|
|
|
| |
| |
| |
|
|
| def ml_effort_node(state: AllocationState) -> Dict[str, Any]: |
| """ |
| LangGraph node #1: ML Effort Agent. |
| Computes effort matrix for all driver-route pairs using MLEffortAgent. |
| """ |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "STARTED", { |
| "num_drivers": len(state.driver_models), |
| "num_routes": len(state.route_models), |
| }) |
| |
| ml_agent = MLEffortAgent() |
| |
| ev_config = { |
| "safety_margin_pct": state.config_used.get("ev_safety_margin_pct", 10.0) if state.config_used else 10.0, |
| "charging_penalty_weight": state.config_used.get("ev_charging_penalty_weight", 0.3) if state.config_used else 0.3, |
| } |
| |
| drivers = [ModelWrapper(d) for d in state.driver_models] |
| routes = [ModelWrapper(r) for r in state.route_models] |
| |
| effort_result = ml_agent.compute_effort_matrix(drivers=drivers, routes=routes, ev_config=ev_config) |
| |
| effort_dict = { |
| "matrix": effort_result.matrix, |
| "driver_ids": effort_result.driver_ids, |
| "route_ids": effort_result.route_ids, |
| "breakdown": {k: v.model_dump() if hasattr(v, 'model_dump') else v for k, v in effort_result.breakdown.items()}, |
| "stats": effort_result.stats, |
| "infeasible_pairs": list(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else [], |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="ML_EFFORT", step_type="MATRIX_GENERATION", |
| input_snapshot=ml_agent.get_input_snapshot(drivers, routes), |
| output_snapshot={**ml_agent.get_output_snapshot(effort_result), "num_infeasible_ev_pairs": len(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else 0}, |
| ) |
| |
| _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "COMPLETED", { |
| "min_effort": effort_result.stats.get("min", 0), |
| "max_effort": effort_result.stats.get("max", 0), |
| "avg_effort": effort_result.stats.get("avg", 0), |
| }) |
| |
| return {"effort_matrix": effort_dict, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def route_planner_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node #2: Route Planner Agent - Proposal 1 (OR-Tools optimization).""" |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "STARTED", { |
| "num_drivers": len(state.driver_models), "num_routes": len(state.route_models), |
| }) |
| |
| planner_agent = RoutePlannerAgent() |
| from app.schemas.agent_schemas import EffortMatrixResult |
| |
| matrix = state.effort_matrix["matrix"] |
| stats = state.effort_matrix.get("stats") or {"min": 0, "max": 0, "avg": 0} |
| |
| effort_result = EffortMatrixResult( |
| matrix=matrix, driver_ids=state.effort_matrix["driver_ids"], |
| route_ids=state.effort_matrix["route_ids"], breakdown={}, stats=stats, |
| infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])), |
| ) |
| |
| recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0 |
| drivers = [ModelWrapper(d) for d in state.driver_models] |
| routes = [ModelWrapper(r) for r in state.route_models] |
| |
| proposal1 = planner_agent.plan( |
| effort_result=effort_result, drivers=drivers, routes=routes, |
| recovery_targets=state.recovery_targets or {}, |
| recovery_penalty_weight=recovery_penalty_weight, proposal_number=1, |
| ) |
| |
| proposal_dict = { |
| "allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal1.allocation], |
| "total_effort": proposal1.total_effort, "avg_effort": proposal1.avg_effort, |
| "solver_status": proposal1.solver_status, "proposal_number": proposal1.proposal_number, |
| "per_driver_effort": proposal1.per_driver_effort, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="ROUTE_PLANNER", step_type="PROPOSAL_1", |
| input_snapshot=planner_agent.get_input_snapshot(effort_result), |
| output_snapshot=planner_agent.get_output_snapshot(proposal1), |
| ) |
| |
| _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "COMPLETED", { |
| "total_effort": proposal1.total_effort, "num_assignments": len(proposal1.allocation), |
| "solver_status": proposal1.solver_status, |
| }) |
| |
| return {"route_proposal_1": proposal_dict, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def fairness_check_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node #3: Fairness Manager Agent — evaluates Gini/stddev/max_gap.""" |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_1", "STARTED", {"proposal_number": 1}) |
| |
| thresholds = FairnessThresholds( |
| gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33, |
| stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0, |
| max_gap_threshold=state.config_used.get("max_gap_threshold", 25.0) if state.config_used else 25.0, |
| ) |
| |
| fairness_agent = FairnessManagerAgent(thresholds=thresholds) |
| from app.schemas.agent_schemas import RoutePlanResult, AllocationItem |
| |
| proposal_to_check = state.route_proposal_1 |
| plan_result = RoutePlanResult( |
| allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]], |
| total_effort=proposal_to_check["total_effort"], |
| avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / max(len(proposal_to_check["allocation"]), 1)), |
| solver_status=proposal_to_check.get("solver_status", "OPTIMAL"), |
| proposal_number=1, per_driver_effort=proposal_to_check["per_driver_effort"], |
| ) |
| |
| fairness_result = fairness_agent.check(plan_result, proposal_number=1) |
| |
| fairness_dict = { |
| "status": fairness_result.status, "proposal_number": fairness_result.proposal_number, |
| "metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else { |
| "avg_effort": fairness_result.metrics.avg_effort, "std_dev": fairness_result.metrics.std_dev, |
| "gini_index": fairness_result.metrics.gini_index, "max_effort": fairness_result.metrics.max_effort, |
| "min_effort": fairness_result.metrics.min_effort, "max_gap": fairness_result.metrics.max_gap, |
| }, |
| "recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="FAIRNESS_MANAGER", step_type="FAIRNESS_CHECK_PROPOSAL_1", |
| input_snapshot=fairness_agent.get_input_snapshot(plan_result), |
| output_snapshot=fairness_agent.get_output_snapshot(fairness_result), |
| ) |
| |
| _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_1", "COMPLETED", { |
| "status": fairness_result.status, "gini_index": fairness_dict["metrics"]["gini_index"], |
| }) |
| |
| return {"fairness_check_1": fairness_dict, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node for second fairness check (after re-optimization).""" |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_2", "STARTED", {"proposal_number": 2}) |
| |
| thresholds = FairnessThresholds( |
| gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33, |
| stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0, |
| max_gap_threshold=state.config_used.get("max_gap_threshold", 25.0) if state.config_used else 25.0, |
| ) |
| |
| fairness_agent = FairnessManagerAgent(thresholds=thresholds) |
| from app.schemas.agent_schemas import RoutePlanResult, AllocationItem |
| |
| proposal_to_check = state.route_proposal_2 |
| plan_result = RoutePlanResult( |
| allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]], |
| total_effort=proposal_to_check["total_effort"], |
| avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / max(len(proposal_to_check["allocation"]), 1)), |
| solver_status=proposal_to_check.get("solver_status", "OPTIMAL"), |
| proposal_number=2, per_driver_effort=proposal_to_check["per_driver_effort"], |
| ) |
| |
| fairness_result = fairness_agent.check(plan_result, proposal_number=2) |
| |
| fairness_dict = { |
| "status": fairness_result.status, "proposal_number": 2, |
| "metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else { |
| "avg_effort": fairness_result.metrics.avg_effort, "std_dev": fairness_result.metrics.std_dev, |
| "gini_index": fairness_result.metrics.gini_index, "max_effort": fairness_result.metrics.max_effort, |
| "min_effort": fairness_result.metrics.min_effort, "max_gap": fairness_result.metrics.max_gap, |
| }, |
| "recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="FAIRNESS_MANAGER", step_type="FAIRNESS_CHECK_PROPOSAL_2", |
| input_snapshot=fairness_agent.get_input_snapshot(plan_result), |
| output_snapshot=fairness_agent.get_output_snapshot(fairness_result), |
| ) |
| |
| _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_2", "COMPLETED", { |
| "status": fairness_result.status, "gini_index": fairness_dict["metrics"]["gini_index"], |
| }) |
| |
| return {"fairness_check_2": fairness_dict, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node #4: Route Planner - Proposal 2 with fairness penalties.""" |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "STARTED", {"reason": "fairness_reoptimization"}) |
| |
| planner_agent = RoutePlannerAgent() |
| from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations |
| |
| matrix = state.effort_matrix["matrix"] |
| stats = state.effort_matrix.get("stats") or {"min": 0, "max": 0, "avg": 0} |
| |
| effort_result = EffortMatrixResult( |
| matrix=matrix, driver_ids=state.effort_matrix["driver_ids"], |
| route_ids=state.effort_matrix["route_ids"], breakdown={}, stats=stats, |
| infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])), |
| ) |
| |
| recommendations_dict = state.fairness_check_1.get("recommendations") |
| penalties = {} |
| if recommendations_dict: |
| recommendations = FairnessRecommendations(**recommendations_dict) |
| penalties = planner_agent.build_penalties_from_recommendations(recommendations, state.route_proposal_1["per_driver_effort"]) |
| |
| recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0 |
| drivers = [ModelWrapper(d) for d in state.driver_models] |
| routes = [ModelWrapper(r) for r in state.route_models] |
| |
| proposal2 = planner_agent.plan( |
| effort_result=effort_result, drivers=drivers, routes=routes, |
| fairness_penalties=penalties, recovery_targets=state.recovery_targets or {}, |
| recovery_penalty_weight=recovery_penalty_weight, proposal_number=2, |
| ) |
| |
| proposal_dict = { |
| "allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal2.allocation], |
| "total_effort": proposal2.total_effort, "avg_effort": proposal2.avg_effort, |
| "solver_status": proposal2.solver_status, "proposal_number": 2, |
| "per_driver_effort": proposal2.per_driver_effort, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="ROUTE_PLANNER", step_type="PROPOSAL_2", |
| input_snapshot=planner_agent.get_input_snapshot(effort_result, penalties), |
| output_snapshot=planner_agent.get_output_snapshot(proposal2), |
| ) |
| |
| _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "COMPLETED", { |
| "total_effort": proposal2.total_effort, "solver_status": proposal2.solver_status, |
| }) |
| |
| return {"route_proposal_2": proposal_dict, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]: |
| """Select best proposal based on fairness metrics comparison.""" |
| final_proposal = state.route_proposal_1 |
| final_fairness = state.fairness_check_1 |
| |
| if state.route_proposal_2 and state.fairness_check_2: |
| check1_metrics = state.fairness_check_1["metrics"] |
| check2_metrics = state.fairness_check_2["metrics"] |
| if (check2_metrics["gini_index"] <= check1_metrics["gini_index"] or |
| check2_metrics["max_gap"] < check1_metrics["max_gap"]): |
| final_proposal = state.route_proposal_2 |
| final_fairness = state.fairness_check_2 |
| |
| return {"final_proposal": final_proposal, "final_fairness": final_fairness, "final_per_driver_effort": final_proposal["per_driver_effort"]} |
|
|
|
|
| |
| |
| |
|
|
| def driver_liaison_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node #6: Driver Liaison - per-driver comfort band negotiation.""" |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "STARTED", {"num_drivers": len(state.driver_models)}) |
| |
| from app.schemas.agent_schemas import AllocationItem |
| liaison_agent = DriverLiaisonAgent() |
| |
| final_proposal = state.final_proposal or state.route_proposal_1 |
| final_fairness = state.final_fairness or state.fairness_check_1 |
| |
| sorted_allocations = sorted(final_proposal["allocation"], key=lambda x: x["effort"], reverse=True) |
| driver_proposals: List[DriverAssignmentProposal] = [] |
| for rank, alloc_item in enumerate(sorted_allocations, start=1): |
| driver_proposals.append(DriverAssignmentProposal( |
| driver_id=str(alloc_item["driver_id"]), route_id=str(alloc_item["route_id"]), |
| effort=alloc_item["effort"], rank_in_team=rank, |
| )) |
| |
| metrics = final_fairness["metrics"] |
| driver_context_objs: Dict[str, DriverContext] = {} |
| for driver_id, context_dict in (state.driver_contexts or {}).items(): |
| driver_context_objs[driver_id] = DriverContext(**context_dict) |
| |
| negotiation_result = liaison_agent.run_for_all_drivers( |
| proposals=driver_proposals, driver_contexts=driver_context_objs, |
| effort_matrix=state.effort_matrix["matrix"], driver_ids=state.effort_matrix["driver_ids"], |
| route_ids=state.effort_matrix["route_ids"], |
| global_avg_effort=metrics["avg_effort"], global_std_effort=metrics["std_dev"], |
| ) |
| |
| liaison_dict = { |
| "decisions": [d.model_dump() if hasattr(d, 'model_dump') else d for d in negotiation_result.decisions], |
| "num_accept": negotiation_result.num_accept, |
| "num_counter": negotiation_result.num_counter, |
| "num_force_accept": negotiation_result.num_force_accept, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="DRIVER_LIAISON", step_type="NEGOTIATION_DECISIONS", |
| input_snapshot=liaison_agent.get_input_snapshot(driver_proposals, metrics["avg_effort"], metrics["std_dev"]), |
| output_snapshot=liaison_agent.get_output_snapshot(negotiation_result), |
| ) |
| |
| _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "COMPLETED", { |
| "num_accept": negotiation_result.num_accept, "num_counter": negotiation_result.num_counter, |
| }) |
| |
| return {"liaison_feedback": liaison_dict, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def final_resolution_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node #7: Final Resolution - resolves COUNTER decisions via swaps.""" |
| run_id = state.allocation_run_id |
| from app.schemas.agent_schemas import RoutePlanResult, AllocationItem, FairnessMetrics, DriverLiaisonDecision |
| |
| counter_decisions = [d for d in state.liaison_feedback["decisions"] if d["decision"] == "COUNTER"] |
| |
| if not counter_decisions: |
| _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"reason": "no_counters", "swaps_applied": 0}) |
| return {"resolution_result": {"swaps_applied": []}} |
| |
| _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {"num_counters": len(counter_decisions)}) |
| |
| resolution_agent = FinalResolutionAgent() |
| final_proposal = state.final_proposal or state.route_proposal_1 |
| final_fairness = state.final_fairness or state.fairness_check_1 |
| |
| approved_proposal = RoutePlanResult( |
| allocation=[AllocationItem(**a) for a in final_proposal["allocation"]], |
| total_effort=final_proposal["total_effort"], |
| avg_effort=final_proposal.get("avg_effort", final_proposal["total_effort"] / max(len(final_proposal["allocation"]), 1)), |
| solver_status=final_proposal.get("solver_status", "OPTIMAL"), |
| proposal_number=final_proposal["proposal_number"], |
| per_driver_effort=final_proposal["per_driver_effort"], |
| ) |
| |
| decisions = [DriverLiaisonDecision(**d) for d in state.liaison_feedback["decisions"]] |
| current_metrics = FairnessMetrics(**final_fairness["metrics"]) |
| |
| resolution_result = resolution_agent.resolve_counters( |
| approved_proposal=approved_proposal, decisions=decisions, |
| effort_matrix=state.effort_matrix["matrix"], |
| driver_ids=state.effort_matrix["driver_ids"], route_ids=state.effort_matrix["route_ids"], |
| current_metrics=current_metrics, |
| ) |
| |
| resolution_dict = { |
| "swaps_applied": [s.model_dump() if hasattr(s, 'model_dump') else s for s in resolution_result.swaps_applied], |
| "allocation": resolution_result.allocation, |
| "per_driver_effort": resolution_result.per_driver_effort, |
| "metrics": resolution_result.metrics, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="FINAL_RESOLUTION", step_type="SWAP_RESOLUTION", |
| input_snapshot=resolution_agent.get_input_snapshot(len(counter_decisions), current_metrics, final_fairness["metrics"]["avg_effort"]), |
| output_snapshot=resolution_agent.get_output_snapshot(resolution_result), |
| ) |
| |
| _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"swaps_applied": len(resolution_result.swaps_applied)}) |
| |
| updated_effort = state.final_per_driver_effort.copy() if state.final_per_driver_effort else {} |
| if resolution_result.swaps_applied: |
| updated_effort = resolution_result.per_driver_effort |
| |
| return {"resolution_result": resolution_dict, "final_per_driver_effort": updated_effort, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def explainability_node(state: AllocationState) -> Dict[str, Any]: |
| """LangGraph node #8: Explainability Agent — generates per-driver explanations.""" |
| run_id = state.allocation_run_id |
| |
| _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {"num_drivers": len(state.driver_models)}) |
| |
| explain_agent = ExplainabilityAgent() |
| final_proposal = state.final_proposal or state.route_proposal_1 |
| final_fairness = state.final_fairness or state.fairness_check_1 |
| final_per_driver_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"] |
| |
| metrics = final_fairness["metrics"] |
| avg_effort = metrics["avg_effort"] |
| |
| route_by_id = {str(r["id"]): r for r in state.route_models} |
| driver_by_id = {str(d["id"]): d for d in state.driver_models} |
| route_dict_by_id = {str(r["id"]): rd for r, rd in zip(state.route_models, state.route_dicts)} if state.route_dicts else {} |
| |
| sorted_efforts = sorted(final_per_driver_effort.items(), key=lambda x: x[1], reverse=True) |
| rank_by_driver = {did: idx + 1 for idx, (did, _) in enumerate(sorted_efforts)} |
| num_drivers = len(final_per_driver_effort) |
| |
| liaison_by_driver = {} |
| if state.liaison_feedback: |
| for decision in state.liaison_feedback["decisions"]: |
| liaison_by_driver[decision["driver_id"]] = decision |
| |
| swapped_drivers = set() |
| if state.resolution_result and state.resolution_result.get("swaps_applied"): |
| for swap in state.resolution_result["swaps_applied"]: |
| swapped_drivers.add(swap.get("driver_a", "")) |
| swapped_drivers.add(swap.get("driver_b", "")) |
| |
| explanations: Dict[str, Dict[str, Any]] = {} |
| category_counts: Dict[str, int] = {} |
| |
| for alloc_item in final_proposal["allocation"]: |
| driver_id_str = str(alloc_item["driver_id"]) |
| route_id_str = str(alloc_item["route_id"]) |
| |
| driver = driver_by_id.get(driver_id_str, {}) |
| route = route_by_id.get(route_id_str, {}) |
| |
| effort = final_per_driver_effort.get(driver_id_str, alloc_item["effort"]) |
| fairness_score = calculate_fairness_score(effort, avg_effort) |
| driver_context = (state.driver_contexts or {}).get(driver_id_str, {}) |
| history_efforts = [driver_context.get("recent_avg_effort", avg_effort)] if driver_context else [] |
| history_hard_days = driver_context.get("recent_hard_days", 0) if driver_context else 0 |
| |
| breakdown_key = f"{driver_id_str}:{route_id_str}" |
| effort_breakdown_data = state.effort_matrix.get("breakdown", {}).get(breakdown_key, {}) |
| effort_breakdown = { |
| "physical_effort": effort_breakdown_data.get("physical_effort", 0), |
| "route_complexity": effort_breakdown_data.get("route_complexity", 0), |
| "time_pressure": effort_breakdown_data.get("time_pressure", 0), |
| } |
| |
| liaison_decision = liaison_by_driver.get(driver_id_str) |
| is_recovery = history_hard_days >= 3 and effort < avg_effort * 0.85 |
| |
| explain_input = DriverExplanationInput( |
| driver_id=driver_id_str, driver_name=driver.get("name", "Driver"), |
| num_drivers=num_drivers, today_effort=effort, |
| today_rank=rank_by_driver.get(driver_id_str, num_drivers), |
| route_id=route_id_str, |
| route_summary={"num_packages": route.get("num_packages", 0), "total_weight_kg": route.get("total_weight_kg", 0), "num_stops": route.get("num_stops", 0), "difficulty_score": route.get("route_difficulty_score", 0), "estimated_time_minutes": route.get("estimated_time_minutes", 0)}, |
| effort_breakdown=effort_breakdown, |
| global_avg_effort=avg_effort, global_std_effort=metrics["std_dev"], |
| global_gini_index=metrics["gini_index"], global_max_gap=metrics["max_gap"], |
| history_efforts_last_7_days=history_efforts, |
| history_hard_days_last_7=history_hard_days, is_recovery_day=is_recovery, |
| had_manual_override=False, |
| liaison_decision=liaison_decision["decision"] if liaison_decision else None, |
| swap_applied=driver_id_str in swapped_drivers, |
| ) |
| |
| explain_output = explain_agent.build_explanation_for_driver(explain_input) |
| category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1 |
| explanations[driver_id_str] = { |
| "driver_explanation": explain_output.driver_explanation, |
| "admin_explanation": explain_output.admin_explanation, |
| "category": explain_output.category, |
| } |
| |
| log_entry = _create_decision_log( |
| agent_name="EXPLAINABILITY", step_type="EXPLANATIONS_GENERATED", |
| input_snapshot=explain_agent.get_input_snapshot(num_drivers=num_drivers, avg_effort=avg_effort, std_effort=metrics["std_dev"], gini_index=metrics["gini_index"], category_counts=category_counts), |
| output_snapshot=explain_agent.get_output_snapshot(total_explanations=len(explanations), category_counts=category_counts), |
| ) |
| |
| _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED", {"total_explanations": len(explanations), "categories": category_counts}) |
| |
| return {"explanations": explanations, "decision_logs": state.decision_logs + [log_entry]} |
|
|
|
|
| |
| |
| |
|
|
| def should_reoptimize(state: AllocationState) -> str: |
| """Conditional: re-optimize if fairness check 1 says REOPTIMIZE and no proposal 2 yet.""" |
| if state.fairness_check_1 and state.fairness_check_1.get("status") == "REOPTIMIZE": |
| if not state.route_proposal_2: |
| return "reoptimize" |
| return "continue" |
|
|
|
|
| def has_counter_decisions(state: AllocationState) -> str: |
| """Conditional: check if any COUNTER decisions need resolution.""" |
| if state.liaison_feedback: |
| if sum(1 for d in state.liaison_feedback["decisions"] if d["decision"] == "COUNTER") > 0: |
| return "resolve" |
| return "skip" |
|
|