fix: ModelWrapper.is_ev enum comparison + improve event publish reliability for real-time SSE
Browse files- brain/app/services/langgraph_nodes.py +133 -457
brain/app/services/langgraph_nodes.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
| 1 |
"""
|
| 2 |
LangGraph node wrappers for Fair Dispatch agents.
|
| 3 |
Each node wraps an existing agent with minimal changes, preserving the original logic.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from datetime import datetime
|
| 7 |
from typing import Dict, Any, List, Optional, Tuple
|
| 8 |
from uuid import UUID
|
| 9 |
import asyncio
|
|
|
|
| 10 |
|
| 11 |
from app.schemas.allocation_state import AllocationState
|
| 12 |
from app.schemas.agent_schemas import (
|
|
@@ -24,6 +29,8 @@ from app.schemas.explainability import DriverExplanationInput
|
|
| 24 |
from app.services.fairness import calculate_fairness_score
|
| 25 |
from app.core.events import agent_event_bus, make_agent_event
|
| 26 |
|
|
|
|
|
|
|
| 27 |
|
| 28 |
class ModelWrapper:
|
| 29 |
"""Helper to wrap dicts as objects for agent compatibility."""
|
|
@@ -35,7 +42,10 @@ class ModelWrapper:
|
|
| 35 |
|
| 36 |
@property
|
| 37 |
def is_ev(self) -> bool:
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def _publish_event_sync(
|
|
@@ -48,6 +58,8 @@ def _publish_event_sync(
|
|
| 48 |
"""
|
| 49 |
Publish an agent event synchronously (fire-and-forget).
|
| 50 |
Used by LangGraph nodes which are synchronous functions.
|
|
|
|
|
|
|
| 51 |
"""
|
| 52 |
if not allocation_run_id:
|
| 53 |
return
|
|
@@ -60,13 +72,17 @@ def _publish_event_sync(
|
|
| 60 |
payload=payload,
|
| 61 |
)
|
| 62 |
|
| 63 |
-
# Schedule async publish
|
| 64 |
try:
|
| 65 |
loop = asyncio.get_running_loop()
|
| 66 |
-
|
| 67 |
except RuntimeError:
|
| 68 |
-
# No running loop
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def _create_decision_log(
|
|
@@ -92,71 +108,49 @@ def _create_decision_log(
|
|
| 92 |
def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
|
| 93 |
"""
|
| 94 |
LangGraph node #1: ML Effort Agent.
|
| 95 |
-
|
| 96 |
Computes effort matrix for all driver-route pairs using MLEffortAgent.
|
| 97 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 98 |
"""
|
| 99 |
run_id = state.allocation_run_id
|
| 100 |
|
| 101 |
-
# Publish STARTED event
|
| 102 |
_publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "STARTED", {
|
| 103 |
"num_drivers": len(state.driver_models),
|
| 104 |
"num_routes": len(state.route_models),
|
| 105 |
})
|
| 106 |
|
| 107 |
-
# Initialize agent
|
| 108 |
ml_agent = MLEffortAgent()
|
| 109 |
|
| 110 |
-
# Get EV config from state
|
| 111 |
ev_config = {
|
| 112 |
"safety_margin_pct": state.config_used.get("ev_safety_margin_pct", 10.0) if state.config_used else 10.0,
|
| 113 |
"charging_penalty_weight": state.config_used.get("ev_charging_penalty_weight", 0.3) if state.config_used else 0.3,
|
| 114 |
}
|
| 115 |
|
| 116 |
-
# Wrap dicts as objects for agent compatibility
|
| 117 |
drivers = [ModelWrapper(d) for d in state.driver_models]
|
| 118 |
routes = [ModelWrapper(r) for r in state.route_models]
|
| 119 |
|
| 120 |
-
|
| 121 |
-
effort_result = ml_agent.compute_effort_matrix(
|
| 122 |
-
drivers=drivers,
|
| 123 |
-
routes=routes,
|
| 124 |
-
ev_config=ev_config,
|
| 125 |
-
)
|
| 126 |
|
| 127 |
-
# Serialize result for state
|
| 128 |
effort_dict = {
|
| 129 |
"matrix": effort_result.matrix,
|
| 130 |
"driver_ids": effort_result.driver_ids,
|
| 131 |
"route_ids": effort_result.route_ids,
|
| 132 |
-
"breakdown": {k: v.model_dump() if hasattr(v, 'model_dump') else v
|
| 133 |
-
for k, v in effort_result.breakdown.items()},
|
| 134 |
"stats": effort_result.stats,
|
| 135 |
"infeasible_pairs": list(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else [],
|
| 136 |
}
|
| 137 |
|
| 138 |
-
# Create decision log
|
| 139 |
log_entry = _create_decision_log(
|
| 140 |
-
agent_name="ML_EFFORT",
|
| 141 |
-
step_type="MATRIX_GENERATION",
|
| 142 |
input_snapshot=ml_agent.get_input_snapshot(drivers, routes),
|
| 143 |
-
output_snapshot={
|
| 144 |
-
**ml_agent.get_output_snapshot(effort_result),
|
| 145 |
-
"num_infeasible_ev_pairs": len(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else 0,
|
| 146 |
-
},
|
| 147 |
)
|
| 148 |
|
| 149 |
-
# Publish COMPLETED event
|
| 150 |
_publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "COMPLETED", {
|
| 151 |
"min_effort": effort_result.stats.get("min", 0),
|
| 152 |
"max_effort": effort_result.stats.get("max", 0),
|
| 153 |
"avg_effort": effort_result.stats.get("avg", 0),
|
| 154 |
})
|
| 155 |
|
| 156 |
-
return {
|
| 157 |
-
"effort_matrix": effort_dict,
|
| 158 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 159 |
-
}
|
| 160 |
|
| 161 |
|
| 162 |
# =============================================================================
|
|
@@ -164,91 +158,54 @@ def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 164 |
# =============================================================================
|
| 165 |
|
| 166 |
def route_planner_node(state: AllocationState) -> Dict[str, Any]:
|
| 167 |
-
"""
|
| 168 |
-
LangGraph node #2: Route Planner Agent - Proposal 1.
|
| 169 |
-
|
| 170 |
-
Generates optimal driver-route assignment using OR-Tools.
|
| 171 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 172 |
-
"""
|
| 173 |
run_id = state.allocation_run_id
|
| 174 |
|
| 175 |
-
# Publish STARTED event
|
| 176 |
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "STARTED", {
|
| 177 |
-
"num_drivers": len(state.driver_models),
|
| 178 |
-
"num_routes": len(state.route_models),
|
| 179 |
})
|
| 180 |
|
| 181 |
planner_agent = RoutePlannerAgent()
|
|
|
|
| 182 |
|
| 183 |
-
# Reconstruct EffortMatrixResult-like object for planner
|
| 184 |
-
from app.schemas.agent_schemas import EffortMatrixResult, EffortBreakdown
|
| 185 |
-
|
| 186 |
-
# Use stats from serialized state or compute if not available
|
| 187 |
matrix = state.effort_matrix["matrix"]
|
| 188 |
-
stats = state.effort_matrix.get("stats")
|
| 189 |
-
if not stats:
|
| 190 |
-
all_values = [v for row in matrix for v in row if v < float('inf')]
|
| 191 |
-
stats = {
|
| 192 |
-
"min": min(all_values) if all_values else 0.0,
|
| 193 |
-
"max": max(all_values) if all_values else 0.0,
|
| 194 |
-
"avg": sum(all_values) / len(all_values) if all_values else 0.0,
|
| 195 |
-
}
|
| 196 |
|
| 197 |
effort_result = EffortMatrixResult(
|
| 198 |
-
matrix=matrix,
|
| 199 |
-
|
| 200 |
-
route_ids=state.effort_matrix["route_ids"],
|
| 201 |
-
breakdown={}, # Simplified - full breakdown not needed for planning
|
| 202 |
-
stats=stats,
|
| 203 |
infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
|
| 204 |
)
|
| 205 |
|
| 206 |
-
# Get recovery penalty weight
|
| 207 |
recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
|
| 208 |
-
|
| 209 |
-
# Wrap dicts as objects for agent compatibility
|
| 210 |
drivers = [ModelWrapper(d) for d in state.driver_models]
|
| 211 |
routes = [ModelWrapper(r) for r in state.route_models]
|
| 212 |
|
| 213 |
-
# Generate Proposal 1 (EXISTING CODE - UNCHANGED)
|
| 214 |
proposal1 = planner_agent.plan(
|
| 215 |
-
effort_result=effort_result,
|
| 216 |
-
drivers=drivers,
|
| 217 |
-
routes=routes,
|
| 218 |
recovery_targets=state.recovery_targets or {},
|
| 219 |
-
recovery_penalty_weight=recovery_penalty_weight,
|
| 220 |
-
proposal_number=1,
|
| 221 |
)
|
| 222 |
|
| 223 |
-
# Serialize result
|
| 224 |
proposal_dict = {
|
| 225 |
"allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal1.allocation],
|
| 226 |
-
"total_effort": proposal1.total_effort,
|
| 227 |
-
"
|
| 228 |
-
"solver_status": proposal1.solver_status,
|
| 229 |
-
"proposal_number": proposal1.proposal_number,
|
| 230 |
"per_driver_effort": proposal1.per_driver_effort,
|
| 231 |
}
|
| 232 |
|
| 233 |
-
# Create decision log
|
| 234 |
log_entry = _create_decision_log(
|
| 235 |
-
agent_name="ROUTE_PLANNER",
|
| 236 |
-
step_type="PROPOSAL_1",
|
| 237 |
input_snapshot=planner_agent.get_input_snapshot(effort_result),
|
| 238 |
output_snapshot=planner_agent.get_output_snapshot(proposal1),
|
| 239 |
)
|
| 240 |
|
| 241 |
-
# Publish COMPLETED event
|
| 242 |
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "COMPLETED", {
|
| 243 |
-
"total_effort": proposal1.total_effort,
|
| 244 |
-
"num_assignments": len(proposal1.allocation),
|
| 245 |
"solver_status": proposal1.solver_status,
|
| 246 |
})
|
| 247 |
|
| 248 |
-
return {
|
| 249 |
-
"route_proposal_1": proposal_dict,
|
| 250 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 251 |
-
}
|
| 252 |
|
| 253 |
|
| 254 |
# =============================================================================
|
|
@@ -256,21 +213,11 @@ def route_planner_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 256 |
# =============================================================================
|
| 257 |
|
| 258 |
def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
|
| 259 |
-
"""
|
| 260 |
-
LangGraph node #3: Fairness Manager Agent.
|
| 261 |
-
|
| 262 |
-
Evaluates fairness metrics and decides ACCEPT or REOPTIMIZE.
|
| 263 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 264 |
-
"""
|
| 265 |
run_id = state.allocation_run_id
|
| 266 |
-
proposal_number = 2 if state.route_proposal_2 else 1
|
| 267 |
|
| 268 |
-
|
| 269 |
-
_publish_event_sync(run_id, "FAIRNESS_MANAGER", f"FAIRNESS_CHECK_{proposal_number}", "STARTED", {
|
| 270 |
-
"proposal_number": proposal_number,
|
| 271 |
-
})
|
| 272 |
|
| 273 |
-
# Get thresholds from config
|
| 274 |
thresholds = FairnessThresholds(
|
| 275 |
gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
|
| 276 |
stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
|
|
@@ -278,84 +225,48 @@ def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 278 |
)
|
| 279 |
|
| 280 |
fairness_agent = FairnessManagerAgent(thresholds=thresholds)
|
| 281 |
-
|
| 282 |
-
# Reconstruct RoutePlanResult for fairness check
|
| 283 |
from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
|
| 284 |
|
| 285 |
-
|
| 286 |
-
proposal_to_check = state.route_proposal_2 or state.route_proposal_1
|
| 287 |
-
|
| 288 |
plan_result = RoutePlanResult(
|
| 289 |
allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
|
| 290 |
total_effort=proposal_to_check["total_effort"],
|
| 291 |
-
avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / len(proposal_to_check["allocation"])
|
| 292 |
solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
|
| 293 |
-
proposal_number=
|
| 294 |
-
per_driver_effort=proposal_to_check["per_driver_effort"],
|
| 295 |
)
|
| 296 |
|
| 297 |
-
|
| 298 |
-
fairness_result = fairness_agent.check(plan_result, proposal_number=proposal_number)
|
| 299 |
|
| 300 |
-
# Serialize result
|
| 301 |
fairness_dict = {
|
| 302 |
-
"status": fairness_result.status,
|
| 303 |
-
"proposal_number": fairness_result.proposal_number,
|
| 304 |
"metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
|
| 305 |
-
"avg_effort": fairness_result.metrics.avg_effort,
|
| 306 |
-
"
|
| 307 |
-
"
|
| 308 |
-
"max_effort": fairness_result.metrics.max_effort,
|
| 309 |
-
"min_effort": fairness_result.metrics.min_effort,
|
| 310 |
-
"max_gap": fairness_result.metrics.max_gap,
|
| 311 |
},
|
| 312 |
"recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
|
| 313 |
}
|
| 314 |
|
| 315 |
-
# Create decision log
|
| 316 |
log_entry = _create_decision_log(
|
| 317 |
-
agent_name="FAIRNESS_MANAGER",
|
| 318 |
-
step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
|
| 319 |
input_snapshot=fairness_agent.get_input_snapshot(plan_result),
|
| 320 |
output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
|
| 321 |
)
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
"status": fairness_result.status,
|
| 326 |
-
"gini_index": fairness_dict["metrics"]["gini_index"],
|
| 327 |
-
"std_dev": fairness_dict["metrics"]["std_dev"],
|
| 328 |
})
|
| 329 |
|
| 330 |
-
|
| 331 |
-
updates = {
|
| 332 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 333 |
-
}
|
| 334 |
-
|
| 335 |
-
if proposal_number == 1:
|
| 336 |
-
updates["fairness_check_1"] = fairness_dict
|
| 337 |
-
else:
|
| 338 |
-
updates["fairness_check_2"] = fairness_dict
|
| 339 |
-
|
| 340 |
-
return updates
|
| 341 |
|
| 342 |
|
| 343 |
def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
|
| 344 |
-
"""
|
| 345 |
-
LangGraph node for second fairness check (dedicated wrapper).
|
| 346 |
-
|
| 347 |
-
This is a separate function to avoid LangGraph state key conflicts
|
| 348 |
-
when the same function is used for multiple nodes.
|
| 349 |
-
"""
|
| 350 |
run_id = state.allocation_run_id
|
| 351 |
-
proposal_number = 2 # Always proposal 2 for this node
|
| 352 |
|
| 353 |
-
|
| 354 |
-
_publish_event_sync(run_id, "FAIRNESS_MANAGER", f"FAIRNESS_CHECK_{proposal_number}", "STARTED", {
|
| 355 |
-
"proposal_number": proposal_number,
|
| 356 |
-
})
|
| 357 |
|
| 358 |
-
# Get thresholds from config
|
| 359 |
thresholds = FairnessThresholds(
|
| 360 |
gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
|
| 361 |
stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
|
|
@@ -363,164 +274,98 @@ def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 363 |
)
|
| 364 |
|
| 365 |
fairness_agent = FairnessManagerAgent(thresholds=thresholds)
|
| 366 |
-
|
| 367 |
-
# Reconstruct RoutePlanResult for fairness check
|
| 368 |
from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
|
| 369 |
|
| 370 |
-
# Always check proposal 2 for this node
|
| 371 |
proposal_to_check = state.route_proposal_2
|
| 372 |
-
|
| 373 |
plan_result = RoutePlanResult(
|
| 374 |
allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
|
| 375 |
total_effort=proposal_to_check["total_effort"],
|
| 376 |
-
avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / len(proposal_to_check["allocation"])
|
| 377 |
solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
|
| 378 |
-
proposal_number=
|
| 379 |
-
per_driver_effort=proposal_to_check["per_driver_effort"],
|
| 380 |
)
|
| 381 |
|
| 382 |
-
|
| 383 |
-
fairness_result = fairness_agent.check(plan_result, proposal_number=proposal_number)
|
| 384 |
|
| 385 |
-
# Serialize result
|
| 386 |
fairness_dict = {
|
| 387 |
-
"status": fairness_result.status,
|
| 388 |
-
"proposal_number": fairness_result.proposal_number,
|
| 389 |
"metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
|
| 390 |
-
"avg_effort": fairness_result.metrics.avg_effort,
|
| 391 |
-
"
|
| 392 |
-
"
|
| 393 |
-
"max_effort": fairness_result.metrics.max_effort,
|
| 394 |
-
"min_effort": fairness_result.metrics.min_effort,
|
| 395 |
-
"max_gap": fairness_result.metrics.max_gap,
|
| 396 |
},
|
| 397 |
"recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
|
| 398 |
}
|
| 399 |
|
| 400 |
-
# Create decision log
|
| 401 |
log_entry = _create_decision_log(
|
| 402 |
-
agent_name="FAIRNESS_MANAGER",
|
| 403 |
-
step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
|
| 404 |
input_snapshot=fairness_agent.get_input_snapshot(plan_result),
|
| 405 |
output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
|
| 406 |
)
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
"status": fairness_result.status,
|
| 411 |
-
"gini_index": fairness_dict["metrics"]["gini_index"],
|
| 412 |
-
"std_dev": fairness_dict["metrics"]["std_dev"],
|
| 413 |
})
|
| 414 |
|
| 415 |
-
return {
|
| 416 |
-
"fairness_check_2": fairness_dict,
|
| 417 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 418 |
-
}
|
| 419 |
|
| 420 |
|
| 421 |
# =============================================================================
|
| 422 |
-
# Node 4: Route Planner
|
| 423 |
# =============================================================================
|
| 424 |
|
| 425 |
-
|
| 426 |
def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
|
| 427 |
-
"""
|
| 428 |
-
LangGraph node #4: Route Planner Agent - Proposal 2 (re-optimization).
|
| 429 |
-
|
| 430 |
-
Re-runs OR-Tools with fairness penalties applied.
|
| 431 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 432 |
-
"""
|
| 433 |
run_id = state.allocation_run_id
|
| 434 |
|
| 435 |
-
|
| 436 |
-
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "STARTED", {
|
| 437 |
-
"reason": "fairness_reoptimization",
|
| 438 |
-
})
|
| 439 |
|
| 440 |
planner_agent = RoutePlannerAgent()
|
| 441 |
-
|
| 442 |
-
# Reconstruct effort result
|
| 443 |
from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations
|
| 444 |
|
| 445 |
-
# Use stats from serialized state or compute if not available
|
| 446 |
matrix = state.effort_matrix["matrix"]
|
| 447 |
-
stats = state.effort_matrix.get("stats")
|
| 448 |
-
if not stats:
|
| 449 |
-
all_values = [v for row in matrix for v in row if v < float('inf')]
|
| 450 |
-
stats = {
|
| 451 |
-
"min": min(all_values) if all_values else 0.0,
|
| 452 |
-
"max": max(all_values) if all_values else 0.0,
|
| 453 |
-
"avg": sum(all_values) / len(all_values) if all_values else 0.0,
|
| 454 |
-
}
|
| 455 |
|
| 456 |
effort_result = EffortMatrixResult(
|
| 457 |
-
matrix=matrix,
|
| 458 |
-
|
| 459 |
-
route_ids=state.effort_matrix["route_ids"],
|
| 460 |
-
breakdown={},
|
| 461 |
-
stats=stats,
|
| 462 |
infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
|
| 463 |
)
|
| 464 |
|
| 465 |
-
# Build penalties from recommendations
|
| 466 |
recommendations_dict = state.fairness_check_1.get("recommendations")
|
| 467 |
penalties = {}
|
| 468 |
-
|
| 469 |
if recommendations_dict:
|
| 470 |
recommendations = FairnessRecommendations(**recommendations_dict)
|
| 471 |
-
penalties = planner_agent.build_penalties_from_recommendations(
|
| 472 |
-
recommendations,
|
| 473 |
-
state.route_proposal_1["per_driver_effort"],
|
| 474 |
-
)
|
| 475 |
|
| 476 |
-
# Get recovery settings
|
| 477 |
recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
|
| 478 |
-
|
| 479 |
-
# Wrap dicts as objects for agent compatibility
|
| 480 |
drivers = [ModelWrapper(d) for d in state.driver_models]
|
| 481 |
routes = [ModelWrapper(r) for r in state.route_models]
|
| 482 |
|
| 483 |
-
# Generate Proposal 2 (EXISTING CODE - UNCHANGED)
|
| 484 |
proposal2 = planner_agent.plan(
|
| 485 |
-
effort_result=effort_result,
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
fairness_penalties=penalties,
|
| 489 |
-
recovery_targets=state.recovery_targets or {},
|
| 490 |
-
recovery_penalty_weight=recovery_penalty_weight,
|
| 491 |
-
proposal_number=2,
|
| 492 |
)
|
| 493 |
|
| 494 |
-
# Serialize result
|
| 495 |
proposal_dict = {
|
| 496 |
"allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal2.allocation],
|
| 497 |
-
"total_effort": proposal2.total_effort,
|
| 498 |
-
"
|
| 499 |
-
"solver_status": proposal2.solver_status,
|
| 500 |
-
"proposal_number": proposal2.proposal_number,
|
| 501 |
"per_driver_effort": proposal2.per_driver_effort,
|
| 502 |
}
|
| 503 |
|
| 504 |
-
# Create decision log
|
| 505 |
log_entry = _create_decision_log(
|
| 506 |
-
agent_name="ROUTE_PLANNER",
|
| 507 |
-
step_type="PROPOSAL_2",
|
| 508 |
input_snapshot=planner_agent.get_input_snapshot(effort_result, penalties),
|
| 509 |
output_snapshot=planner_agent.get_output_snapshot(proposal2),
|
| 510 |
)
|
| 511 |
|
| 512 |
-
# Publish COMPLETED event
|
| 513 |
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "COMPLETED", {
|
| 514 |
-
"total_effort": proposal2.total_effort,
|
| 515 |
-
"num_assignments": len(proposal2.allocation),
|
| 516 |
-
"solver_status": proposal2.solver_status,
|
| 517 |
})
|
| 518 |
|
| 519 |
-
return {
|
| 520 |
-
"route_proposal_2": proposal_dict,
|
| 521 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 522 |
-
}
|
| 523 |
-
|
| 524 |
|
| 525 |
|
| 526 |
# =============================================================================
|
|
@@ -528,31 +373,19 @@ def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 528 |
# =============================================================================
|
| 529 |
|
| 530 |
def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
|
| 531 |
-
"""
|
| 532 |
-
Select the final proposal after fairness checks.
|
| 533 |
-
|
| 534 |
-
If proposal 2 exists and has better fairness, use it.
|
| 535 |
-
Otherwise, use proposal 1.
|
| 536 |
-
"""
|
| 537 |
final_proposal = state.route_proposal_1
|
| 538 |
final_fairness = state.fairness_check_1
|
| 539 |
|
| 540 |
if state.route_proposal_2 and state.fairness_check_2:
|
| 541 |
-
# Compare fairness metrics
|
| 542 |
check1_metrics = state.fairness_check_1["metrics"]
|
| 543 |
check2_metrics = state.fairness_check_2["metrics"]
|
| 544 |
-
|
| 545 |
-
# Use proposal 2 if it improves fairness
|
| 546 |
if (check2_metrics["gini_index"] <= check1_metrics["gini_index"] or
|
| 547 |
check2_metrics["max_gap"] < check1_metrics["max_gap"]):
|
| 548 |
final_proposal = state.route_proposal_2
|
| 549 |
final_fairness = state.fairness_check_2
|
| 550 |
|
| 551 |
-
return {
|
| 552 |
-
"final_proposal": final_proposal,
|
| 553 |
-
"final_fairness": final_fairness,
|
| 554 |
-
"final_per_driver_effort": final_proposal["per_driver_effort"],
|
| 555 |
-
}
|
| 556 |
|
| 557 |
|
| 558 |
# =============================================================================
|
|
@@ -560,64 +393,37 @@ def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 560 |
# =============================================================================
|
| 561 |
|
| 562 |
def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
|
| 563 |
-
"""
|
| 564 |
-
LangGraph node #6: Driver Liaison Agent.
|
| 565 |
-
|
| 566 |
-
Reviews proposed assignments and makes ACCEPT/COUNTER decisions per driver.
|
| 567 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 568 |
-
"""
|
| 569 |
run_id = state.allocation_run_id
|
| 570 |
|
| 571 |
-
|
| 572 |
-
_publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "STARTED", {
|
| 573 |
-
"num_drivers": len(state.driver_models),
|
| 574 |
-
})
|
| 575 |
|
| 576 |
from app.schemas.agent_schemas import AllocationItem
|
| 577 |
-
|
| 578 |
liaison_agent = DriverLiaisonAgent()
|
| 579 |
|
| 580 |
final_proposal = state.final_proposal or state.route_proposal_1
|
| 581 |
final_fairness = state.final_fairness or state.fairness_check_1
|
| 582 |
|
| 583 |
-
|
| 584 |
-
sorted_allocations = sorted(
|
| 585 |
-
final_proposal["allocation"],
|
| 586 |
-
key=lambda x: x["effort"],
|
| 587 |
-
reverse=True # Highest effort = rank 1
|
| 588 |
-
)
|
| 589 |
-
|
| 590 |
driver_proposals: List[DriverAssignmentProposal] = []
|
| 591 |
for rank, alloc_item in enumerate(sorted_allocations, start=1):
|
| 592 |
driver_proposals.append(DriverAssignmentProposal(
|
| 593 |
-
driver_id=str(alloc_item["driver_id"]),
|
| 594 |
-
|
| 595 |
-
effort=alloc_item["effort"],
|
| 596 |
-
rank_in_team=rank,
|
| 597 |
))
|
| 598 |
|
| 599 |
-
# Get global metrics
|
| 600 |
metrics = final_fairness["metrics"]
|
| 601 |
-
global_avg_effort = metrics["avg_effort"]
|
| 602 |
-
global_std_effort = metrics["std_dev"]
|
| 603 |
-
|
| 604 |
-
# Build DriverContext objects
|
| 605 |
driver_context_objs: Dict[str, DriverContext] = {}
|
| 606 |
-
for driver_id, context_dict in state.driver_contexts.items():
|
| 607 |
driver_context_objs[driver_id] = DriverContext(**context_dict)
|
| 608 |
|
| 609 |
-
# Run liaison for all drivers (EXISTING CODE - UNCHANGED)
|
| 610 |
negotiation_result = liaison_agent.run_for_all_drivers(
|
| 611 |
-
proposals=driver_proposals,
|
| 612 |
-
|
| 613 |
-
effort_matrix=state.effort_matrix["matrix"],
|
| 614 |
-
driver_ids=state.effort_matrix["driver_ids"],
|
| 615 |
route_ids=state.effort_matrix["route_ids"],
|
| 616 |
-
global_avg_effort=
|
| 617 |
-
global_std_effort=global_std_effort,
|
| 618 |
)
|
| 619 |
|
| 620 |
-
# Serialize result
|
| 621 |
liaison_dict = {
|
| 622 |
"decisions": [d.model_dump() if hasattr(d, 'model_dump') else d for d in negotiation_result.decisions],
|
| 623 |
"num_accept": negotiation_result.num_accept,
|
|
@@ -625,29 +431,17 @@ def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 625 |
"num_force_accept": negotiation_result.num_force_accept,
|
| 626 |
}
|
| 627 |
|
| 628 |
-
# Create decision log
|
| 629 |
log_entry = _create_decision_log(
|
| 630 |
-
agent_name="DRIVER_LIAISON",
|
| 631 |
-
|
| 632 |
-
input_snapshot=liaison_agent.get_input_snapshot(
|
| 633 |
-
driver_proposals,
|
| 634 |
-
global_avg_effort,
|
| 635 |
-
global_std_effort,
|
| 636 |
-
),
|
| 637 |
output_snapshot=liaison_agent.get_output_snapshot(negotiation_result),
|
| 638 |
)
|
| 639 |
|
| 640 |
-
# Publish COMPLETED event
|
| 641 |
_publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "COMPLETED", {
|
| 642 |
-
"num_accept": negotiation_result.num_accept,
|
| 643 |
-
"num_counter": negotiation_result.num_counter,
|
| 644 |
-
"num_force_accept": negotiation_result.num_force_accept,
|
| 645 |
})
|
| 646 |
|
| 647 |
-
return {
|
| 648 |
-
"liaison_feedback": liaison_dict,
|
| 649 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 650 |
-
}
|
| 651 |
|
| 652 |
|
| 653 |
# =============================================================================
|
|
@@ -655,101 +449,61 @@ def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 655 |
# =============================================================================
|
| 656 |
|
| 657 |
def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
|
| 658 |
-
"""
|
| 659 |
-
LangGraph node #7: Final Resolution Agent.
|
| 660 |
-
|
| 661 |
-
Resolves COUNTER decisions through swaps.
|
| 662 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 663 |
-
"""
|
| 664 |
run_id = state.allocation_run_id
|
| 665 |
from app.schemas.agent_schemas import RoutePlanResult, AllocationItem, FairnessMetrics, DriverLiaisonDecision
|
| 666 |
|
| 667 |
-
|
| 668 |
-
counter_decisions = [
|
| 669 |
-
d for d in state.liaison_feedback["decisions"]
|
| 670 |
-
if d["decision"] == "COUNTER"
|
| 671 |
-
]
|
| 672 |
|
| 673 |
if not counter_decisions:
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
"reason": "no_counters",
|
| 677 |
-
"swaps_applied": 0,
|
| 678 |
-
})
|
| 679 |
-
# No resolution needed
|
| 680 |
-
return {
|
| 681 |
-
"resolution_result": {"swaps_applied": []},
|
| 682 |
-
}
|
| 683 |
|
| 684 |
-
|
| 685 |
-
_publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {
|
| 686 |
-
"num_counters": len(counter_decisions),
|
| 687 |
-
})
|
| 688 |
|
| 689 |
resolution_agent = FinalResolutionAgent()
|
| 690 |
-
|
| 691 |
-
# Reconstruct objects for resolution
|
| 692 |
final_proposal = state.final_proposal or state.route_proposal_1
|
| 693 |
final_fairness = state.final_fairness or state.fairness_check_1
|
| 694 |
|
| 695 |
approved_proposal = RoutePlanResult(
|
| 696 |
allocation=[AllocationItem(**a) for a in final_proposal["allocation"]],
|
| 697 |
total_effort=final_proposal["total_effort"],
|
| 698 |
-
avg_effort=final_proposal.get("avg_effort", final_proposal["total_effort"] / len(final_proposal["allocation"])
|
| 699 |
solver_status=final_proposal.get("solver_status", "OPTIMAL"),
|
| 700 |
proposal_number=final_proposal["proposal_number"],
|
| 701 |
per_driver_effort=final_proposal["per_driver_effort"],
|
| 702 |
)
|
| 703 |
|
| 704 |
decisions = [DriverLiaisonDecision(**d) for d in state.liaison_feedback["decisions"]]
|
| 705 |
-
|
| 706 |
current_metrics = FairnessMetrics(**final_fairness["metrics"])
|
| 707 |
|
| 708 |
-
# Resolve counters (EXISTING CODE - UNCHANGED)
|
| 709 |
resolution_result = resolution_agent.resolve_counters(
|
| 710 |
-
approved_proposal=approved_proposal,
|
| 711 |
-
decisions=decisions,
|
| 712 |
effort_matrix=state.effort_matrix["matrix"],
|
| 713 |
-
driver_ids=state.effort_matrix["driver_ids"],
|
| 714 |
-
route_ids=state.effort_matrix["route_ids"],
|
| 715 |
current_metrics=current_metrics,
|
| 716 |
)
|
| 717 |
|
| 718 |
-
# Serialize result
|
| 719 |
resolution_dict = {
|
| 720 |
"swaps_applied": [s.model_dump() if hasattr(s, 'model_dump') else s for s in resolution_result.swaps_applied],
|
| 721 |
-
"allocation": resolution_result.allocation,
|
| 722 |
"per_driver_effort": resolution_result.per_driver_effort,
|
| 723 |
"metrics": resolution_result.metrics,
|
| 724 |
}
|
| 725 |
|
| 726 |
-
# Create decision log
|
| 727 |
log_entry = _create_decision_log(
|
| 728 |
-
agent_name="
|
| 729 |
-
|
| 730 |
-
input_snapshot=resolution_agent.get_input_snapshot(
|
| 731 |
-
len(counter_decisions),
|
| 732 |
-
current_metrics,
|
| 733 |
-
final_fairness["metrics"]["avg_effort"],
|
| 734 |
-
),
|
| 735 |
output_snapshot=resolution_agent.get_output_snapshot(resolution_result),
|
| 736 |
)
|
| 737 |
|
| 738 |
-
|
| 739 |
-
_publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {
|
| 740 |
-
"swaps_applied": len(resolution_result.swaps_applied),
|
| 741 |
-
})
|
| 742 |
|
| 743 |
-
|
| 744 |
-
updated_effort = state.final_per_driver_effort.copy()
|
| 745 |
if resolution_result.swaps_applied:
|
| 746 |
updated_effort = resolution_result.per_driver_effort
|
| 747 |
|
| 748 |
-
return {
|
| 749 |
-
"resolution_result": resolution_dict,
|
| 750 |
-
"final_per_driver_effort": updated_effort,
|
| 751 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 752 |
-
}
|
| 753 |
|
| 754 |
|
| 755 |
# =============================================================================
|
|
@@ -757,21 +511,12 @@ def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 757 |
# =============================================================================
|
| 758 |
|
| 759 |
def explainability_node(state: AllocationState) -> Dict[str, Any]:
|
| 760 |
-
"""
|
| 761 |
-
LangGraph node #8: Explainability Agent.
|
| 762 |
-
|
| 763 |
-
Generates template-based explanations for each driver.
|
| 764 |
-
WRAPS EXISTING AGENT - no logic changes.
|
| 765 |
-
"""
|
| 766 |
run_id = state.allocation_run_id
|
| 767 |
|
| 768 |
-
|
| 769 |
-
_publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {
|
| 770 |
-
"num_drivers": len(state.driver_models),
|
| 771 |
-
})
|
| 772 |
|
| 773 |
explain_agent = ExplainabilityAgent()
|
| 774 |
-
|
| 775 |
final_proposal = state.final_proposal or state.route_proposal_1
|
| 776 |
final_fairness = state.final_fairness or state.fairness_check_1
|
| 777 |
final_per_driver_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"]
|
|
@@ -779,32 +524,24 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 779 |
metrics = final_fairness["metrics"]
|
| 780 |
avg_effort = metrics["avg_effort"]
|
| 781 |
|
| 782 |
-
# Build lookup structures
|
| 783 |
route_by_id = {str(r["id"]): r for r in state.route_models}
|
| 784 |
driver_by_id = {str(d["id"]): d for d in state.driver_models}
|
| 785 |
-
route_dict_by_id = {str(r["id"]): rd for r, rd in zip(state.route_models, state.route_dicts)}
|
| 786 |
|
| 787 |
-
|
| 788 |
-
sorted_efforts = sorted(
|
| 789 |
-
final_per_driver_effort.items(),
|
| 790 |
-
key=lambda x: x[1],
|
| 791 |
-
reverse=True
|
| 792 |
-
)
|
| 793 |
rank_by_driver = {did: idx + 1 for idx, (did, _) in enumerate(sorted_efforts)}
|
| 794 |
num_drivers = len(final_per_driver_effort)
|
| 795 |
|
| 796 |
-
# Build liaison decisions lookup
|
| 797 |
liaison_by_driver = {}
|
| 798 |
if state.liaison_feedback:
|
| 799 |
for decision in state.liaison_feedback["decisions"]:
|
| 800 |
liaison_by_driver[decision["driver_id"]] = decision
|
| 801 |
|
| 802 |
-
# Build swaps lookup
|
| 803 |
swapped_drivers = set()
|
| 804 |
if state.resolution_result and state.resolution_result.get("swaps_applied"):
|
| 805 |
for swap in state.resolution_result["swaps_applied"]:
|
| 806 |
-
swapped_drivers.add(swap
|
| 807 |
-
swapped_drivers.add(swap
|
| 808 |
|
| 809 |
explanations: Dict[str, Dict[str, Any]] = {}
|
| 810 |
category_counts: Dict[str, int] = {}
|
|
@@ -815,18 +552,13 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 815 |
|
| 816 |
driver = driver_by_id.get(driver_id_str, {})
|
| 817 |
route = route_by_id.get(route_id_str, {})
|
| 818 |
-
route_dict = route_dict_by_id.get(route_id_str, {})
|
| 819 |
|
| 820 |
-
# Use resolved effort if available
|
| 821 |
effort = final_per_driver_effort.get(driver_id_str, alloc_item["effort"])
|
| 822 |
fairness_score = calculate_fairness_score(effort, avg_effort)
|
| 823 |
-
|
| 824 |
-
# Get driver context
|
| 825 |
-
driver_context = state.driver_contexts.get(driver_id_str, {})
|
| 826 |
history_efforts = [driver_context.get("recent_avg_effort", avg_effort)] if driver_context else []
|
| 827 |
history_hard_days = driver_context.get("recent_hard_days", 0) if driver_context else 0
|
| 828 |
|
| 829 |
-
# Get effort breakdown
|
| 830 |
breakdown_key = f"{driver_id_str}:{route_id_str}"
|
| 831 |
effort_breakdown_data = state.effort_matrix.get("breakdown", {}).get(breakdown_key, {})
|
| 832 |
effort_breakdown = {
|
|
@@ -835,82 +567,42 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 835 |
"time_pressure": effort_breakdown_data.get("time_pressure", 0),
|
| 836 |
}
|
| 837 |
|
| 838 |
-
# Get liaison decision
|
| 839 |
liaison_decision = liaison_by_driver.get(driver_id_str)
|
|
|
|
| 840 |
|
| 841 |
-
# Determine if recovery day
|
| 842 |
-
is_recovery = (
|
| 843 |
-
history_hard_days >= 3 and
|
| 844 |
-
effort < avg_effort * 0.85
|
| 845 |
-
)
|
| 846 |
-
|
| 847 |
-
# Build explanation input
|
| 848 |
explain_input = DriverExplanationInput(
|
| 849 |
-
driver_id=driver_id_str,
|
| 850 |
-
|
| 851 |
-
num_drivers=num_drivers,
|
| 852 |
-
today_effort=effort,
|
| 853 |
today_rank=rank_by_driver.get(driver_id_str, num_drivers),
|
| 854 |
route_id=route_id_str,
|
| 855 |
-
route_summary={
|
| 856 |
-
"num_packages": route.get("num_packages", 0),
|
| 857 |
-
"total_weight_kg": route.get("total_weight_kg", 0),
|
| 858 |
-
"num_stops": route.get("num_stops", 0),
|
| 859 |
-
"difficulty_score": route.get("route_difficulty_score", 0),
|
| 860 |
-
"estimated_time_minutes": route.get("estimated_time_minutes", 0),
|
| 861 |
-
},
|
| 862 |
effort_breakdown=effort_breakdown,
|
| 863 |
-
global_avg_effort=avg_effort,
|
| 864 |
-
|
| 865 |
-
global_gini_index=metrics["gini_index"],
|
| 866 |
-
global_max_gap=metrics["max_gap"],
|
| 867 |
history_efforts_last_7_days=history_efforts,
|
| 868 |
-
history_hard_days_last_7=history_hard_days,
|
| 869 |
-
|
| 870 |
-
had_manual_override=False, # TODO: Query DB if needed
|
| 871 |
liaison_decision=liaison_decision["decision"] if liaison_decision else None,
|
| 872 |
swap_applied=driver_id_str in swapped_drivers,
|
| 873 |
)
|
| 874 |
|
| 875 |
-
# Generate explanations (EXISTING CODE - UNCHANGED)
|
| 876 |
explain_output = explain_agent.build_explanation_for_driver(explain_input)
|
| 877 |
-
|
| 878 |
-
# Track category counts
|
| 879 |
category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1
|
| 880 |
-
|
| 881 |
explanations[driver_id_str] = {
|
| 882 |
"driver_explanation": explain_output.driver_explanation,
|
| 883 |
"admin_explanation": explain_output.admin_explanation,
|
| 884 |
"category": explain_output.category,
|
| 885 |
}
|
| 886 |
|
| 887 |
-
# Create decision log
|
| 888 |
log_entry = _create_decision_log(
|
| 889 |
-
agent_name="EXPLAINABILITY",
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
num_drivers=num_drivers,
|
| 893 |
-
avg_effort=avg_effort,
|
| 894 |
-
std_effort=metrics["std_dev"],
|
| 895 |
-
gini_index=metrics["gini_index"],
|
| 896 |
-
category_counts=category_counts,
|
| 897 |
-
),
|
| 898 |
-
output_snapshot=explain_agent.get_output_snapshot(
|
| 899 |
-
total_explanations=len(explanations),
|
| 900 |
-
category_counts=category_counts,
|
| 901 |
-
),
|
| 902 |
)
|
| 903 |
|
| 904 |
-
|
| 905 |
-
_publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED", {
|
| 906 |
-
"total_explanations": len(explanations),
|
| 907 |
-
"categories": category_counts,
|
| 908 |
-
})
|
| 909 |
|
| 910 |
-
return {
|
| 911 |
-
"explanations": explanations,
|
| 912 |
-
"decision_logs": state.decision_logs + [log_entry],
|
| 913 |
-
}
|
| 914 |
|
| 915 |
|
| 916 |
# =============================================================================
|
|
@@ -918,13 +610,7 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
|
|
| 918 |
# =============================================================================
|
| 919 |
|
| 920 |
def should_reoptimize(state: AllocationState) -> str:
|
| 921 |
-
"""
|
| 922 |
-
Conditional edge: decide if re-optimization is needed.
|
| 923 |
-
|
| 924 |
-
Returns:
|
| 925 |
-
"reoptimize" - if fairness check 1 says REOPTIMIZE and no proposal 2 yet
|
| 926 |
-
"continue" - otherwise
|
| 927 |
-
"""
|
| 928 |
if state.fairness_check_1 and state.fairness_check_1.get("status") == "REOPTIMIZE":
|
| 929 |
if not state.route_proposal_2:
|
| 930 |
return "reoptimize"
|
|
@@ -932,18 +618,8 @@ def should_reoptimize(state: AllocationState) -> str:
|
|
| 932 |
|
| 933 |
|
| 934 |
def has_counter_decisions(state: AllocationState) -> str:
|
| 935 |
-
"""
|
| 936 |
-
Conditional edge: check if any COUNTER decisions need resolution.
|
| 937 |
-
|
| 938 |
-
Returns:
|
| 939 |
-
"resolve" - if there are COUNTER decisions
|
| 940 |
-
"skip" - otherwise
|
| 941 |
-
"""
|
| 942 |
if state.liaison_feedback:
|
| 943 |
-
|
| 944 |
-
1 for d in state.liaison_feedback["decisions"]
|
| 945 |
-
if d["decision"] == "COUNTER"
|
| 946 |
-
)
|
| 947 |
-
if counter_count > 0:
|
| 948 |
return "resolve"
|
| 949 |
return "skip"
|
|
|
|
| 1 |
"""
|
| 2 |
LangGraph node wrappers for Fair Dispatch agents.
|
| 3 |
Each node wraps an existing agent with minimal changes, preserving the original logic.
|
| 4 |
+
|
| 5 |
+
PRODUCTION FIXES APPLIED:
|
| 6 |
+
- ModelWrapper.is_ev: handles "EV", "ELECTRIC", VehicleType.EV enum values
|
| 7 |
+
- _publish_event_sync: improved reliability with asyncio.ensure_future
|
| 8 |
"""
|
| 9 |
|
| 10 |
from datetime import datetime
|
| 11 |
from typing import Dict, Any, List, Optional, Tuple
|
| 12 |
from uuid import UUID
|
| 13 |
import asyncio
|
| 14 |
+
import logging
|
| 15 |
|
| 16 |
from app.schemas.allocation_state import AllocationState
|
| 17 |
from app.schemas.agent_schemas import (
|
|
|
|
| 29 |
from app.services.fairness import calculate_fairness_score
|
| 30 |
from app.core.events import agent_event_bus, make_agent_event
|
| 31 |
|
| 32 |
+
logger = logging.getLogger("fairrelay.langgraph")
|
| 33 |
+
|
| 34 |
|
| 35 |
class ModelWrapper:
|
| 36 |
"""Helper to wrap dicts as objects for agent compatibility."""
|
|
|
|
| 42 |
|
| 43 |
@property
|
| 44 |
def is_ev(self) -> bool:
|
| 45 |
+
"""Check if driver has an EV - handles all possible enum/string formats."""
|
| 46 |
+
vt = self._data.get("vehicle_type", "")
|
| 47 |
+
vt_str = str(vt).upper()
|
| 48 |
+
return vt_str in ("EV", "ELECTRIC", "VEHICLETYPE.EV")
|
| 49 |
|
| 50 |
|
| 51 |
def _publish_event_sync(
|
|
|
|
| 58 |
"""
|
| 59 |
Publish an agent event synchronously (fire-and-forget).
|
| 60 |
Used by LangGraph nodes which are synchronous functions.
|
| 61 |
+
|
| 62 |
+
Uses asyncio.ensure_future for reliable delivery when a loop is running.
|
| 63 |
"""
|
| 64 |
if not allocation_run_id:
|
| 65 |
return
|
|
|
|
| 72 |
payload=payload,
|
| 73 |
)
|
| 74 |
|
| 75 |
+
# Schedule async publish on the running event loop
|
| 76 |
try:
|
| 77 |
loop = asyncio.get_running_loop()
|
| 78 |
+
asyncio.ensure_future(agent_event_bus.publish(event), loop=loop)
|
| 79 |
except RuntimeError:
|
| 80 |
+
# No running loop — this shouldn't happen in FastAPI context
|
| 81 |
+
# but handle gracefully for testing
|
| 82 |
+
try:
|
| 83 |
+
asyncio.run(agent_event_bus.publish(event))
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.warning(f"Failed to publish agent event: {e}")
|
| 86 |
|
| 87 |
|
| 88 |
def _create_decision_log(
|
|
|
|
| 108 |
def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
|
| 109 |
"""
|
| 110 |
LangGraph node #1: ML Effort Agent.
|
|
|
|
| 111 |
Computes effort matrix for all driver-route pairs using MLEffortAgent.
|
|
|
|
| 112 |
"""
|
| 113 |
run_id = state.allocation_run_id
|
| 114 |
|
|
|
|
| 115 |
_publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "STARTED", {
|
| 116 |
"num_drivers": len(state.driver_models),
|
| 117 |
"num_routes": len(state.route_models),
|
| 118 |
})
|
| 119 |
|
|
|
|
| 120 |
ml_agent = MLEffortAgent()
|
| 121 |
|
|
|
|
| 122 |
ev_config = {
|
| 123 |
"safety_margin_pct": state.config_used.get("ev_safety_margin_pct", 10.0) if state.config_used else 10.0,
|
| 124 |
"charging_penalty_weight": state.config_used.get("ev_charging_penalty_weight", 0.3) if state.config_used else 0.3,
|
| 125 |
}
|
| 126 |
|
|
|
|
| 127 |
drivers = [ModelWrapper(d) for d in state.driver_models]
|
| 128 |
routes = [ModelWrapper(r) for r in state.route_models]
|
| 129 |
|
| 130 |
+
effort_result = ml_agent.compute_effort_matrix(drivers=drivers, routes=routes, ev_config=ev_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
|
|
|
| 132 |
effort_dict = {
|
| 133 |
"matrix": effort_result.matrix,
|
| 134 |
"driver_ids": effort_result.driver_ids,
|
| 135 |
"route_ids": effort_result.route_ids,
|
| 136 |
+
"breakdown": {k: v.model_dump() if hasattr(v, 'model_dump') else v for k, v in effort_result.breakdown.items()},
|
|
|
|
| 137 |
"stats": effort_result.stats,
|
| 138 |
"infeasible_pairs": list(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else [],
|
| 139 |
}
|
| 140 |
|
|
|
|
| 141 |
log_entry = _create_decision_log(
|
| 142 |
+
agent_name="ML_EFFORT", step_type="MATRIX_GENERATION",
|
|
|
|
| 143 |
input_snapshot=ml_agent.get_input_snapshot(drivers, routes),
|
| 144 |
+
output_snapshot={**ml_agent.get_output_snapshot(effort_result), "num_infeasible_ev_pairs": len(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else 0},
|
|
|
|
|
|
|
|
|
|
| 145 |
)
|
| 146 |
|
|
|
|
| 147 |
_publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "COMPLETED", {
|
| 148 |
"min_effort": effort_result.stats.get("min", 0),
|
| 149 |
"max_effort": effort_result.stats.get("max", 0),
|
| 150 |
"avg_effort": effort_result.stats.get("avg", 0),
|
| 151 |
})
|
| 152 |
|
| 153 |
+
return {"effort_matrix": effort_dict, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
# =============================================================================
|
|
|
|
| 158 |
# =============================================================================
|
| 159 |
|
| 160 |
def route_planner_node(state: AllocationState) -> Dict[str, Any]:
|
| 161 |
+
"""LangGraph node #2: Route Planner Agent - Proposal 1 (OR-Tools optimization)."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
run_id = state.allocation_run_id
|
| 163 |
|
|
|
|
| 164 |
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "STARTED", {
|
| 165 |
+
"num_drivers": len(state.driver_models), "num_routes": len(state.route_models),
|
|
|
|
| 166 |
})
|
| 167 |
|
| 168 |
planner_agent = RoutePlannerAgent()
|
| 169 |
+
from app.schemas.agent_schemas import EffortMatrixResult
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
matrix = state.effort_matrix["matrix"]
|
| 172 |
+
stats = state.effort_matrix.get("stats") or {"min": 0, "max": 0, "avg": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
effort_result = EffortMatrixResult(
|
| 175 |
+
matrix=matrix, driver_ids=state.effort_matrix["driver_ids"],
|
| 176 |
+
route_ids=state.effort_matrix["route_ids"], breakdown={}, stats=stats,
|
|
|
|
|
|
|
|
|
|
| 177 |
infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
|
| 178 |
)
|
| 179 |
|
|
|
|
| 180 |
recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
|
|
|
|
|
|
|
| 181 |
drivers = [ModelWrapper(d) for d in state.driver_models]
|
| 182 |
routes = [ModelWrapper(r) for r in state.route_models]
|
| 183 |
|
|
|
|
| 184 |
proposal1 = planner_agent.plan(
|
| 185 |
+
effort_result=effort_result, drivers=drivers, routes=routes,
|
|
|
|
|
|
|
| 186 |
recovery_targets=state.recovery_targets or {},
|
| 187 |
+
recovery_penalty_weight=recovery_penalty_weight, proposal_number=1,
|
|
|
|
| 188 |
)
|
| 189 |
|
|
|
|
| 190 |
proposal_dict = {
|
| 191 |
"allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal1.allocation],
|
| 192 |
+
"total_effort": proposal1.total_effort, "avg_effort": proposal1.avg_effort,
|
| 193 |
+
"solver_status": proposal1.solver_status, "proposal_number": proposal1.proposal_number,
|
|
|
|
|
|
|
| 194 |
"per_driver_effort": proposal1.per_driver_effort,
|
| 195 |
}
|
| 196 |
|
|
|
|
| 197 |
log_entry = _create_decision_log(
|
| 198 |
+
agent_name="ROUTE_PLANNER", step_type="PROPOSAL_1",
|
|
|
|
| 199 |
input_snapshot=planner_agent.get_input_snapshot(effort_result),
|
| 200 |
output_snapshot=planner_agent.get_output_snapshot(proposal1),
|
| 201 |
)
|
| 202 |
|
|
|
|
| 203 |
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "COMPLETED", {
|
| 204 |
+
"total_effort": proposal1.total_effort, "num_assignments": len(proposal1.allocation),
|
|
|
|
| 205 |
"solver_status": proposal1.solver_status,
|
| 206 |
})
|
| 207 |
|
| 208 |
+
return {"route_proposal_1": proposal_dict, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
|
| 211 |
# =============================================================================
|
|
|
|
| 213 |
# =============================================================================
|
| 214 |
|
| 215 |
def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
|
| 216 |
+
"""LangGraph node #3: Fairness Manager Agent — evaluates Gini/stddev/max_gap."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
run_id = state.allocation_run_id
|
|
|
|
| 218 |
|
| 219 |
+
_publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_1", "STARTED", {"proposal_number": 1})
|
|
|
|
|
|
|
|
|
|
| 220 |
|
|
|
|
| 221 |
thresholds = FairnessThresholds(
|
| 222 |
gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
|
| 223 |
stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
|
|
|
|
| 225 |
)
|
| 226 |
|
| 227 |
fairness_agent = FairnessManagerAgent(thresholds=thresholds)
|
|
|
|
|
|
|
| 228 |
from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
|
| 229 |
|
| 230 |
+
proposal_to_check = state.route_proposal_1
|
|
|
|
|
|
|
| 231 |
plan_result = RoutePlanResult(
|
| 232 |
allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
|
| 233 |
total_effort=proposal_to_check["total_effort"],
|
| 234 |
+
avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / max(len(proposal_to_check["allocation"]), 1)),
|
| 235 |
solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
|
| 236 |
+
proposal_number=1, per_driver_effort=proposal_to_check["per_driver_effort"],
|
|
|
|
| 237 |
)
|
| 238 |
|
| 239 |
+
fairness_result = fairness_agent.check(plan_result, proposal_number=1)
|
|
|
|
| 240 |
|
|
|
|
| 241 |
fairness_dict = {
|
| 242 |
+
"status": fairness_result.status, "proposal_number": fairness_result.proposal_number,
|
|
|
|
| 243 |
"metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
|
| 244 |
+
"avg_effort": fairness_result.metrics.avg_effort, "std_dev": fairness_result.metrics.std_dev,
|
| 245 |
+
"gini_index": fairness_result.metrics.gini_index, "max_effort": fairness_result.metrics.max_effort,
|
| 246 |
+
"min_effort": fairness_result.metrics.min_effort, "max_gap": fairness_result.metrics.max_gap,
|
|
|
|
|
|
|
|
|
|
| 247 |
},
|
| 248 |
"recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
|
| 249 |
}
|
| 250 |
|
|
|
|
| 251 |
log_entry = _create_decision_log(
|
| 252 |
+
agent_name="FAIRNESS_MANAGER", step_type="FAIRNESS_CHECK_PROPOSAL_1",
|
|
|
|
| 253 |
input_snapshot=fairness_agent.get_input_snapshot(plan_result),
|
| 254 |
output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
|
| 255 |
)
|
| 256 |
|
| 257 |
+
_publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_1", "COMPLETED", {
|
| 258 |
+
"status": fairness_result.status, "gini_index": fairness_dict["metrics"]["gini_index"],
|
|
|
|
|
|
|
|
|
|
| 259 |
})
|
| 260 |
|
| 261 |
+
return {"fairness_check_1": fairness_dict, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
|
| 264 |
def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
|
| 265 |
+
"""LangGraph node for second fairness check (after re-optimization)."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
run_id = state.allocation_run_id
|
|
|
|
| 267 |
|
| 268 |
+
_publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_2", "STARTED", {"proposal_number": 2})
|
|
|
|
|
|
|
|
|
|
| 269 |
|
|
|
|
| 270 |
thresholds = FairnessThresholds(
|
| 271 |
gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
|
| 272 |
stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
|
|
|
|
| 274 |
)
|
| 275 |
|
| 276 |
fairness_agent = FairnessManagerAgent(thresholds=thresholds)
|
|
|
|
|
|
|
| 277 |
from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
|
| 278 |
|
|
|
|
| 279 |
proposal_to_check = state.route_proposal_2
|
|
|
|
| 280 |
plan_result = RoutePlanResult(
|
| 281 |
allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
|
| 282 |
total_effort=proposal_to_check["total_effort"],
|
| 283 |
+
avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / max(len(proposal_to_check["allocation"]), 1)),
|
| 284 |
solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
|
| 285 |
+
proposal_number=2, per_driver_effort=proposal_to_check["per_driver_effort"],
|
|
|
|
| 286 |
)
|
| 287 |
|
| 288 |
+
fairness_result = fairness_agent.check(plan_result, proposal_number=2)
|
|
|
|
| 289 |
|
|
|
|
| 290 |
fairness_dict = {
|
| 291 |
+
"status": fairness_result.status, "proposal_number": 2,
|
|
|
|
| 292 |
"metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
|
| 293 |
+
"avg_effort": fairness_result.metrics.avg_effort, "std_dev": fairness_result.metrics.std_dev,
|
| 294 |
+
"gini_index": fairness_result.metrics.gini_index, "max_effort": fairness_result.metrics.max_effort,
|
| 295 |
+
"min_effort": fairness_result.metrics.min_effort, "max_gap": fairness_result.metrics.max_gap,
|
|
|
|
|
|
|
|
|
|
| 296 |
},
|
| 297 |
"recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
|
| 298 |
}
|
| 299 |
|
|
|
|
| 300 |
log_entry = _create_decision_log(
|
| 301 |
+
agent_name="FAIRNESS_MANAGER", step_type="FAIRNESS_CHECK_PROPOSAL_2",
|
|
|
|
| 302 |
input_snapshot=fairness_agent.get_input_snapshot(plan_result),
|
| 303 |
output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
|
| 304 |
)
|
| 305 |
|
| 306 |
+
_publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_2", "COMPLETED", {
|
| 307 |
+
"status": fairness_result.status, "gini_index": fairness_dict["metrics"]["gini_index"],
|
|
|
|
|
|
|
|
|
|
| 308 |
})
|
| 309 |
|
| 310 |
+
return {"fairness_check_2": fairness_dict, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
# =============================================================================
|
| 314 |
+
# Node 4: Route Planner Re-optimization (Proposal 2)
|
| 315 |
# =============================================================================
|
| 316 |
|
|
|
|
| 317 |
def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
|
| 318 |
+
"""LangGraph node #4: Route Planner - Proposal 2 with fairness penalties."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
run_id = state.allocation_run_id
|
| 320 |
|
| 321 |
+
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "STARTED", {"reason": "fairness_reoptimization"})
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
planner_agent = RoutePlannerAgent()
|
|
|
|
|
|
|
| 324 |
from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations
|
| 325 |
|
|
|
|
| 326 |
matrix = state.effort_matrix["matrix"]
|
| 327 |
+
stats = state.effort_matrix.get("stats") or {"min": 0, "max": 0, "avg": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
effort_result = EffortMatrixResult(
|
| 330 |
+
matrix=matrix, driver_ids=state.effort_matrix["driver_ids"],
|
| 331 |
+
route_ids=state.effort_matrix["route_ids"], breakdown={}, stats=stats,
|
|
|
|
|
|
|
|
|
|
| 332 |
infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
|
| 333 |
)
|
| 334 |
|
|
|
|
| 335 |
recommendations_dict = state.fairness_check_1.get("recommendations")
|
| 336 |
penalties = {}
|
|
|
|
| 337 |
if recommendations_dict:
|
| 338 |
recommendations = FairnessRecommendations(**recommendations_dict)
|
| 339 |
+
penalties = planner_agent.build_penalties_from_recommendations(recommendations, state.route_proposal_1["per_driver_effort"])
|
|
|
|
|
|
|
|
|
|
| 340 |
|
|
|
|
| 341 |
recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
|
|
|
|
|
|
|
| 342 |
drivers = [ModelWrapper(d) for d in state.driver_models]
|
| 343 |
routes = [ModelWrapper(r) for r in state.route_models]
|
| 344 |
|
|
|
|
| 345 |
proposal2 = planner_agent.plan(
|
| 346 |
+
effort_result=effort_result, drivers=drivers, routes=routes,
|
| 347 |
+
fairness_penalties=penalties, recovery_targets=state.recovery_targets or {},
|
| 348 |
+
recovery_penalty_weight=recovery_penalty_weight, proposal_number=2,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
)
|
| 350 |
|
|
|
|
| 351 |
proposal_dict = {
|
| 352 |
"allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal2.allocation],
|
| 353 |
+
"total_effort": proposal2.total_effort, "avg_effort": proposal2.avg_effort,
|
| 354 |
+
"solver_status": proposal2.solver_status, "proposal_number": 2,
|
|
|
|
|
|
|
| 355 |
"per_driver_effort": proposal2.per_driver_effort,
|
| 356 |
}
|
| 357 |
|
|
|
|
| 358 |
log_entry = _create_decision_log(
|
| 359 |
+
agent_name="ROUTE_PLANNER", step_type="PROPOSAL_2",
|
|
|
|
| 360 |
input_snapshot=planner_agent.get_input_snapshot(effort_result, penalties),
|
| 361 |
output_snapshot=planner_agent.get_output_snapshot(proposal2),
|
| 362 |
)
|
| 363 |
|
|
|
|
| 364 |
_publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "COMPLETED", {
|
| 365 |
+
"total_effort": proposal2.total_effort, "solver_status": proposal2.solver_status,
|
|
|
|
|
|
|
| 366 |
})
|
| 367 |
|
| 368 |
+
return {"route_proposal_2": proposal_dict, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
|
| 371 |
# =============================================================================
|
|
|
|
| 373 |
# =============================================================================
|
| 374 |
|
| 375 |
def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
|
| 376 |
+
"""Select best proposal based on fairness metrics comparison."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
final_proposal = state.route_proposal_1
|
| 378 |
final_fairness = state.fairness_check_1
|
| 379 |
|
| 380 |
if state.route_proposal_2 and state.fairness_check_2:
|
|
|
|
| 381 |
check1_metrics = state.fairness_check_1["metrics"]
|
| 382 |
check2_metrics = state.fairness_check_2["metrics"]
|
|
|
|
|
|
|
| 383 |
if (check2_metrics["gini_index"] <= check1_metrics["gini_index"] or
|
| 384 |
check2_metrics["max_gap"] < check1_metrics["max_gap"]):
|
| 385 |
final_proposal = state.route_proposal_2
|
| 386 |
final_fairness = state.fairness_check_2
|
| 387 |
|
| 388 |
+
return {"final_proposal": final_proposal, "final_fairness": final_fairness, "final_per_driver_effort": final_proposal["per_driver_effort"]}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
# =============================================================================
|
|
|
|
| 393 |
# =============================================================================
|
| 394 |
|
| 395 |
def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
|
| 396 |
+
"""LangGraph node #6: Driver Liaison - per-driver comfort band negotiation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
run_id = state.allocation_run_id
|
| 398 |
|
| 399 |
+
_publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "STARTED", {"num_drivers": len(state.driver_models)})
|
|
|
|
|
|
|
|
|
|
| 400 |
|
| 401 |
from app.schemas.agent_schemas import AllocationItem
|
|
|
|
| 402 |
liaison_agent = DriverLiaisonAgent()
|
| 403 |
|
| 404 |
final_proposal = state.final_proposal or state.route_proposal_1
|
| 405 |
final_fairness = state.final_fairness or state.fairness_check_1
|
| 406 |
|
| 407 |
+
sorted_allocations = sorted(final_proposal["allocation"], key=lambda x: x["effort"], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
driver_proposals: List[DriverAssignmentProposal] = []
|
| 409 |
for rank, alloc_item in enumerate(sorted_allocations, start=1):
|
| 410 |
driver_proposals.append(DriverAssignmentProposal(
|
| 411 |
+
driver_id=str(alloc_item["driver_id"]), route_id=str(alloc_item["route_id"]),
|
| 412 |
+
effort=alloc_item["effort"], rank_in_team=rank,
|
|
|
|
|
|
|
| 413 |
))
|
| 414 |
|
|
|
|
| 415 |
metrics = final_fairness["metrics"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
driver_context_objs: Dict[str, DriverContext] = {}
|
| 417 |
+
for driver_id, context_dict in (state.driver_contexts or {}).items():
|
| 418 |
driver_context_objs[driver_id] = DriverContext(**context_dict)
|
| 419 |
|
|
|
|
| 420 |
negotiation_result = liaison_agent.run_for_all_drivers(
|
| 421 |
+
proposals=driver_proposals, driver_contexts=driver_context_objs,
|
| 422 |
+
effort_matrix=state.effort_matrix["matrix"], driver_ids=state.effort_matrix["driver_ids"],
|
|
|
|
|
|
|
| 423 |
route_ids=state.effort_matrix["route_ids"],
|
| 424 |
+
global_avg_effort=metrics["avg_effort"], global_std_effort=metrics["std_dev"],
|
|
|
|
| 425 |
)
|
| 426 |
|
|
|
|
| 427 |
liaison_dict = {
|
| 428 |
"decisions": [d.model_dump() if hasattr(d, 'model_dump') else d for d in negotiation_result.decisions],
|
| 429 |
"num_accept": negotiation_result.num_accept,
|
|
|
|
| 431 |
"num_force_accept": negotiation_result.num_force_accept,
|
| 432 |
}
|
| 433 |
|
|
|
|
| 434 |
log_entry = _create_decision_log(
|
| 435 |
+
agent_name="DRIVER_LIAISON", step_type="NEGOTIATION_DECISIONS",
|
| 436 |
+
input_snapshot=liaison_agent.get_input_snapshot(driver_proposals, metrics["avg_effort"], metrics["std_dev"]),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
output_snapshot=liaison_agent.get_output_snapshot(negotiation_result),
|
| 438 |
)
|
| 439 |
|
|
|
|
| 440 |
_publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "COMPLETED", {
|
| 441 |
+
"num_accept": negotiation_result.num_accept, "num_counter": negotiation_result.num_counter,
|
|
|
|
|
|
|
| 442 |
})
|
| 443 |
|
| 444 |
+
return {"liaison_feedback": liaison_dict, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
|
| 447 |
# =============================================================================
|
|
|
|
| 449 |
# =============================================================================
|
| 450 |
|
| 451 |
def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
|
| 452 |
+
"""LangGraph node #7: Final Resolution - resolves COUNTER decisions via swaps."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
run_id = state.allocation_run_id
|
| 454 |
from app.schemas.agent_schemas import RoutePlanResult, AllocationItem, FairnessMetrics, DriverLiaisonDecision
|
| 455 |
|
| 456 |
+
counter_decisions = [d for d in state.liaison_feedback["decisions"] if d["decision"] == "COUNTER"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
|
| 458 |
if not counter_decisions:
|
| 459 |
+
_publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"reason": "no_counters", "swaps_applied": 0})
|
| 460 |
+
return {"resolution_result": {"swaps_applied": []}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
+
_publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {"num_counters": len(counter_decisions)})
|
|
|
|
|
|
|
|
|
|
| 463 |
|
| 464 |
resolution_agent = FinalResolutionAgent()
|
|
|
|
|
|
|
| 465 |
final_proposal = state.final_proposal or state.route_proposal_1
|
| 466 |
final_fairness = state.final_fairness or state.fairness_check_1
|
| 467 |
|
| 468 |
approved_proposal = RoutePlanResult(
|
| 469 |
allocation=[AllocationItem(**a) for a in final_proposal["allocation"]],
|
| 470 |
total_effort=final_proposal["total_effort"],
|
| 471 |
+
avg_effort=final_proposal.get("avg_effort", final_proposal["total_effort"] / max(len(final_proposal["allocation"]), 1)),
|
| 472 |
solver_status=final_proposal.get("solver_status", "OPTIMAL"),
|
| 473 |
proposal_number=final_proposal["proposal_number"],
|
| 474 |
per_driver_effort=final_proposal["per_driver_effort"],
|
| 475 |
)
|
| 476 |
|
| 477 |
decisions = [DriverLiaisonDecision(**d) for d in state.liaison_feedback["decisions"]]
|
|
|
|
| 478 |
current_metrics = FairnessMetrics(**final_fairness["metrics"])
|
| 479 |
|
|
|
|
| 480 |
resolution_result = resolution_agent.resolve_counters(
|
| 481 |
+
approved_proposal=approved_proposal, decisions=decisions,
|
|
|
|
| 482 |
effort_matrix=state.effort_matrix["matrix"],
|
| 483 |
+
driver_ids=state.effort_matrix["driver_ids"], route_ids=state.effort_matrix["route_ids"],
|
|
|
|
| 484 |
current_metrics=current_metrics,
|
| 485 |
)
|
| 486 |
|
|
|
|
| 487 |
resolution_dict = {
|
| 488 |
"swaps_applied": [s.model_dump() if hasattr(s, 'model_dump') else s for s in resolution_result.swaps_applied],
|
| 489 |
+
"allocation": resolution_result.allocation,
|
| 490 |
"per_driver_effort": resolution_result.per_driver_effort,
|
| 491 |
"metrics": resolution_result.metrics,
|
| 492 |
}
|
| 493 |
|
|
|
|
| 494 |
log_entry = _create_decision_log(
|
| 495 |
+
agent_name="FINAL_RESOLUTION", step_type="SWAP_RESOLUTION",
|
| 496 |
+
input_snapshot=resolution_agent.get_input_snapshot(len(counter_decisions), current_metrics, final_fairness["metrics"]["avg_effort"]),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
output_snapshot=resolution_agent.get_output_snapshot(resolution_result),
|
| 498 |
)
|
| 499 |
|
| 500 |
+
_publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"swaps_applied": len(resolution_result.swaps_applied)})
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
+
updated_effort = state.final_per_driver_effort.copy() if state.final_per_driver_effort else {}
|
|
|
|
| 503 |
if resolution_result.swaps_applied:
|
| 504 |
updated_effort = resolution_result.per_driver_effort
|
| 505 |
|
| 506 |
+
return {"resolution_result": resolution_dict, "final_per_driver_effort": updated_effort, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
|
| 509 |
# =============================================================================
|
|
|
|
| 511 |
# =============================================================================
|
| 512 |
|
| 513 |
def explainability_node(state: AllocationState) -> Dict[str, Any]:
|
| 514 |
+
"""LangGraph node #8: Explainability Agent — generates per-driver explanations."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
run_id = state.allocation_run_id
|
| 516 |
|
| 517 |
+
_publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {"num_drivers": len(state.driver_models)})
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
explain_agent = ExplainabilityAgent()
|
|
|
|
| 520 |
final_proposal = state.final_proposal or state.route_proposal_1
|
| 521 |
final_fairness = state.final_fairness or state.fairness_check_1
|
| 522 |
final_per_driver_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"]
|
|
|
|
| 524 |
metrics = final_fairness["metrics"]
|
| 525 |
avg_effort = metrics["avg_effort"]
|
| 526 |
|
|
|
|
| 527 |
route_by_id = {str(r["id"]): r for r in state.route_models}
|
| 528 |
driver_by_id = {str(d["id"]): d for d in state.driver_models}
|
| 529 |
+
route_dict_by_id = {str(r["id"]): rd for r, rd in zip(state.route_models, state.route_dicts)} if state.route_dicts else {}
|
| 530 |
|
| 531 |
+
sorted_efforts = sorted(final_per_driver_effort.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
rank_by_driver = {did: idx + 1 for idx, (did, _) in enumerate(sorted_efforts)}
|
| 533 |
num_drivers = len(final_per_driver_effort)
|
| 534 |
|
|
|
|
| 535 |
liaison_by_driver = {}
|
| 536 |
if state.liaison_feedback:
|
| 537 |
for decision in state.liaison_feedback["decisions"]:
|
| 538 |
liaison_by_driver[decision["driver_id"]] = decision
|
| 539 |
|
|
|
|
| 540 |
swapped_drivers = set()
|
| 541 |
if state.resolution_result and state.resolution_result.get("swaps_applied"):
|
| 542 |
for swap in state.resolution_result["swaps_applied"]:
|
| 543 |
+
swapped_drivers.add(swap.get("driver_a", ""))
|
| 544 |
+
swapped_drivers.add(swap.get("driver_b", ""))
|
| 545 |
|
| 546 |
explanations: Dict[str, Dict[str, Any]] = {}
|
| 547 |
category_counts: Dict[str, int] = {}
|
|
|
|
| 552 |
|
| 553 |
driver = driver_by_id.get(driver_id_str, {})
|
| 554 |
route = route_by_id.get(route_id_str, {})
|
|
|
|
| 555 |
|
|
|
|
| 556 |
effort = final_per_driver_effort.get(driver_id_str, alloc_item["effort"])
|
| 557 |
fairness_score = calculate_fairness_score(effort, avg_effort)
|
| 558 |
+
driver_context = (state.driver_contexts or {}).get(driver_id_str, {})
|
|
|
|
|
|
|
| 559 |
history_efforts = [driver_context.get("recent_avg_effort", avg_effort)] if driver_context else []
|
| 560 |
history_hard_days = driver_context.get("recent_hard_days", 0) if driver_context else 0
|
| 561 |
|
|
|
|
| 562 |
breakdown_key = f"{driver_id_str}:{route_id_str}"
|
| 563 |
effort_breakdown_data = state.effort_matrix.get("breakdown", {}).get(breakdown_key, {})
|
| 564 |
effort_breakdown = {
|
|
|
|
| 567 |
"time_pressure": effort_breakdown_data.get("time_pressure", 0),
|
| 568 |
}
|
| 569 |
|
|
|
|
| 570 |
liaison_decision = liaison_by_driver.get(driver_id_str)
|
| 571 |
+
is_recovery = history_hard_days >= 3 and effort < avg_effort * 0.85
|
| 572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
explain_input = DriverExplanationInput(
|
| 574 |
+
driver_id=driver_id_str, driver_name=driver.get("name", "Driver"),
|
| 575 |
+
num_drivers=num_drivers, today_effort=effort,
|
|
|
|
|
|
|
| 576 |
today_rank=rank_by_driver.get(driver_id_str, num_drivers),
|
| 577 |
route_id=route_id_str,
|
| 578 |
+
route_summary={"num_packages": route.get("num_packages", 0), "total_weight_kg": route.get("total_weight_kg", 0), "num_stops": route.get("num_stops", 0), "difficulty_score": route.get("route_difficulty_score", 0), "estimated_time_minutes": route.get("estimated_time_minutes", 0)},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
effort_breakdown=effort_breakdown,
|
| 580 |
+
global_avg_effort=avg_effort, global_std_effort=metrics["std_dev"],
|
| 581 |
+
global_gini_index=metrics["gini_index"], global_max_gap=metrics["max_gap"],
|
|
|
|
|
|
|
| 582 |
history_efforts_last_7_days=history_efforts,
|
| 583 |
+
history_hard_days_last_7=history_hard_days, is_recovery_day=is_recovery,
|
| 584 |
+
had_manual_override=False,
|
|
|
|
| 585 |
liaison_decision=liaison_decision["decision"] if liaison_decision else None,
|
| 586 |
swap_applied=driver_id_str in swapped_drivers,
|
| 587 |
)
|
| 588 |
|
|
|
|
| 589 |
explain_output = explain_agent.build_explanation_for_driver(explain_input)
|
|
|
|
|
|
|
| 590 |
category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1
|
|
|
|
| 591 |
explanations[driver_id_str] = {
|
| 592 |
"driver_explanation": explain_output.driver_explanation,
|
| 593 |
"admin_explanation": explain_output.admin_explanation,
|
| 594 |
"category": explain_output.category,
|
| 595 |
}
|
| 596 |
|
|
|
|
| 597 |
log_entry = _create_decision_log(
|
| 598 |
+
agent_name="EXPLAINABILITY", step_type="EXPLANATIONS_GENERATED",
|
| 599 |
+
input_snapshot=explain_agent.get_input_snapshot(num_drivers=num_drivers, avg_effort=avg_effort, std_effort=metrics["std_dev"], gini_index=metrics["gini_index"], category_counts=category_counts),
|
| 600 |
+
output_snapshot=explain_agent.get_output_snapshot(total_explanations=len(explanations), category_counts=category_counts),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
)
|
| 602 |
|
| 603 |
+
_publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED", {"total_explanations": len(explanations), "categories": category_counts})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
|
| 605 |
+
return {"explanations": explanations, "decision_logs": state.decision_logs + [log_entry]}
|
|
|
|
|
|
|
|
|
|
| 606 |
|
| 607 |
|
| 608 |
# =============================================================================
|
|
|
|
| 610 |
# =============================================================================
|
| 611 |
|
| 612 |
def should_reoptimize(state: AllocationState) -> str:
|
| 613 |
+
"""Conditional: re-optimize if fairness check 1 says REOPTIMIZE and no proposal 2 yet."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
if state.fairness_check_1 and state.fairness_check_1.get("status") == "REOPTIMIZE":
|
| 615 |
if not state.route_proposal_2:
|
| 616 |
return "reoptimize"
|
|
|
|
| 618 |
|
| 619 |
|
| 620 |
def has_counter_decisions(state: AllocationState) -> str:
|
| 621 |
+
"""Conditional: check if any COUNTER decisions need resolution."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
if state.liaison_feedback:
|
| 623 |
+
if sum(1 for d in state.liaison_feedback["decisions"] if d["decision"] == "COUNTER") > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
return "resolve"
|
| 625 |
return "skip"
|