fix: ModelWrapper.is_ev enum comparison + improve event publish reliability for real-time SSE

#20
Files changed (1) hide show
  1. brain/app/services/langgraph_nodes.py +133 -457
brain/app/services/langgraph_nodes.py CHANGED
@@ -1,12 +1,17 @@
1
  """
2
  LangGraph node wrappers for Fair Dispatch agents.
3
  Each node wraps an existing agent with minimal changes, preserving the original logic.
 
 
 
 
4
  """
5
 
6
  from datetime import datetime
7
  from typing import Dict, Any, List, Optional, Tuple
8
  from uuid import UUID
9
  import asyncio
 
10
 
11
  from app.schemas.allocation_state import AllocationState
12
  from app.schemas.agent_schemas import (
@@ -24,6 +29,8 @@ from app.schemas.explainability import DriverExplanationInput
24
  from app.services.fairness import calculate_fairness_score
25
  from app.core.events import agent_event_bus, make_agent_event
26
 
 
 
27
 
28
  class ModelWrapper:
29
  """Helper to wrap dicts as objects for agent compatibility."""
@@ -35,7 +42,10 @@ class ModelWrapper:
35
 
36
  @property
37
  def is_ev(self) -> bool:
38
- return self._data.get("vehicle_type") == "EV"
 
 
 
39
 
40
 
41
  def _publish_event_sync(
@@ -48,6 +58,8 @@ def _publish_event_sync(
48
  """
49
  Publish an agent event synchronously (fire-and-forget).
50
  Used by LangGraph nodes which are synchronous functions.
 
 
51
  """
52
  if not allocation_run_id:
53
  return
@@ -60,13 +72,17 @@ def _publish_event_sync(
60
  payload=payload,
61
  )
62
 
63
- # Schedule async publish - get or create event loop
64
  try:
65
  loop = asyncio.get_running_loop()
66
- loop.create_task(agent_event_bus.publish(event))
67
  except RuntimeError:
68
- # No running loop, create one for this publish
69
- asyncio.run(agent_event_bus.publish(event))
 
 
 
 
70
 
71
 
72
  def _create_decision_log(
@@ -92,71 +108,49 @@ def _create_decision_log(
92
  def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
93
  """
94
  LangGraph node #1: ML Effort Agent.
95
-
96
  Computes effort matrix for all driver-route pairs using MLEffortAgent.
97
- WRAPS EXISTING AGENT - no logic changes.
98
  """
99
  run_id = state.allocation_run_id
100
 
101
- # Publish STARTED event
102
  _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "STARTED", {
103
  "num_drivers": len(state.driver_models),
104
  "num_routes": len(state.route_models),
105
  })
106
 
107
- # Initialize agent
108
  ml_agent = MLEffortAgent()
109
 
110
- # Get EV config from state
111
  ev_config = {
112
  "safety_margin_pct": state.config_used.get("ev_safety_margin_pct", 10.0) if state.config_used else 10.0,
113
  "charging_penalty_weight": state.config_used.get("ev_charging_penalty_weight", 0.3) if state.config_used else 0.3,
114
  }
115
 
116
- # Wrap dicts as objects for agent compatibility
117
  drivers = [ModelWrapper(d) for d in state.driver_models]
118
  routes = [ModelWrapper(r) for r in state.route_models]
119
 
120
- # Compute effort matrix (EXISTING CODE - UNCHANGED)
121
- effort_result = ml_agent.compute_effort_matrix(
122
- drivers=drivers,
123
- routes=routes,
124
- ev_config=ev_config,
125
- )
126
 
127
- # Serialize result for state
128
  effort_dict = {
129
  "matrix": effort_result.matrix,
130
  "driver_ids": effort_result.driver_ids,
131
  "route_ids": effort_result.route_ids,
132
- "breakdown": {k: v.model_dump() if hasattr(v, 'model_dump') else v
133
- for k, v in effort_result.breakdown.items()},
134
  "stats": effort_result.stats,
135
  "infeasible_pairs": list(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else [],
136
  }
137
 
138
- # Create decision log
139
  log_entry = _create_decision_log(
140
- agent_name="ML_EFFORT",
141
- step_type="MATRIX_GENERATION",
142
  input_snapshot=ml_agent.get_input_snapshot(drivers, routes),
143
- output_snapshot={
144
- **ml_agent.get_output_snapshot(effort_result),
145
- "num_infeasible_ev_pairs": len(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else 0,
146
- },
147
  )
148
 
149
- # Publish COMPLETED event
150
  _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "COMPLETED", {
151
  "min_effort": effort_result.stats.get("min", 0),
152
  "max_effort": effort_result.stats.get("max", 0),
153
  "avg_effort": effort_result.stats.get("avg", 0),
154
  })
155
 
156
- return {
157
- "effort_matrix": effort_dict,
158
- "decision_logs": state.decision_logs + [log_entry],
159
- }
160
 
161
 
162
  # =============================================================================
@@ -164,91 +158,54 @@ def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
164
  # =============================================================================
165
 
166
  def route_planner_node(state: AllocationState) -> Dict[str, Any]:
167
- """
168
- LangGraph node #2: Route Planner Agent - Proposal 1.
169
-
170
- Generates optimal driver-route assignment using OR-Tools.
171
- WRAPS EXISTING AGENT - no logic changes.
172
- """
173
  run_id = state.allocation_run_id
174
 
175
- # Publish STARTED event
176
  _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "STARTED", {
177
- "num_drivers": len(state.driver_models),
178
- "num_routes": len(state.route_models),
179
  })
180
 
181
  planner_agent = RoutePlannerAgent()
 
182
 
183
- # Reconstruct EffortMatrixResult-like object for planner
184
- from app.schemas.agent_schemas import EffortMatrixResult, EffortBreakdown
185
-
186
- # Use stats from serialized state or compute if not available
187
  matrix = state.effort_matrix["matrix"]
188
- stats = state.effort_matrix.get("stats")
189
- if not stats:
190
- all_values = [v for row in matrix for v in row if v < float('inf')]
191
- stats = {
192
- "min": min(all_values) if all_values else 0.0,
193
- "max": max(all_values) if all_values else 0.0,
194
- "avg": sum(all_values) / len(all_values) if all_values else 0.0,
195
- }
196
 
197
  effort_result = EffortMatrixResult(
198
- matrix=matrix,
199
- driver_ids=state.effort_matrix["driver_ids"],
200
- route_ids=state.effort_matrix["route_ids"],
201
- breakdown={}, # Simplified - full breakdown not needed for planning
202
- stats=stats,
203
  infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
204
  )
205
 
206
- # Get recovery penalty weight
207
  recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
208
-
209
- # Wrap dicts as objects for agent compatibility
210
  drivers = [ModelWrapper(d) for d in state.driver_models]
211
  routes = [ModelWrapper(r) for r in state.route_models]
212
 
213
- # Generate Proposal 1 (EXISTING CODE - UNCHANGED)
214
  proposal1 = planner_agent.plan(
215
- effort_result=effort_result,
216
- drivers=drivers,
217
- routes=routes,
218
  recovery_targets=state.recovery_targets or {},
219
- recovery_penalty_weight=recovery_penalty_weight,
220
- proposal_number=1,
221
  )
222
 
223
- # Serialize result
224
  proposal_dict = {
225
  "allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal1.allocation],
226
- "total_effort": proposal1.total_effort,
227
- "avg_effort": proposal1.avg_effort,
228
- "solver_status": proposal1.solver_status,
229
- "proposal_number": proposal1.proposal_number,
230
  "per_driver_effort": proposal1.per_driver_effort,
231
  }
232
 
233
- # Create decision log
234
  log_entry = _create_decision_log(
235
- agent_name="ROUTE_PLANNER",
236
- step_type="PROPOSAL_1",
237
  input_snapshot=planner_agent.get_input_snapshot(effort_result),
238
  output_snapshot=planner_agent.get_output_snapshot(proposal1),
239
  )
240
 
241
- # Publish COMPLETED event
242
  _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "COMPLETED", {
243
- "total_effort": proposal1.total_effort,
244
- "num_assignments": len(proposal1.allocation),
245
  "solver_status": proposal1.solver_status,
246
  })
247
 
248
- return {
249
- "route_proposal_1": proposal_dict,
250
- "decision_logs": state.decision_logs + [log_entry],
251
- }
252
 
253
 
254
  # =============================================================================
@@ -256,21 +213,11 @@ def route_planner_node(state: AllocationState) -> Dict[str, Any]:
256
  # =============================================================================
257
 
258
  def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
259
- """
260
- LangGraph node #3: Fairness Manager Agent.
261
-
262
- Evaluates fairness metrics and decides ACCEPT or REOPTIMIZE.
263
- WRAPS EXISTING AGENT - no logic changes.
264
- """
265
  run_id = state.allocation_run_id
266
- proposal_number = 2 if state.route_proposal_2 else 1
267
 
268
- # Publish STARTED event
269
- _publish_event_sync(run_id, "FAIRNESS_MANAGER", f"FAIRNESS_CHECK_{proposal_number}", "STARTED", {
270
- "proposal_number": proposal_number,
271
- })
272
 
273
- # Get thresholds from config
274
  thresholds = FairnessThresholds(
275
  gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
276
  stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
@@ -278,84 +225,48 @@ def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
278
  )
279
 
280
  fairness_agent = FairnessManagerAgent(thresholds=thresholds)
281
-
282
- # Reconstruct RoutePlanResult for fairness check
283
  from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
284
 
285
- # Determine which proposal to check
286
- proposal_to_check = state.route_proposal_2 or state.route_proposal_1
287
-
288
  plan_result = RoutePlanResult(
289
  allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
290
  total_effort=proposal_to_check["total_effort"],
291
- avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / len(proposal_to_check["allocation"]) if proposal_to_check["allocation"] else 0.0),
292
  solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
293
- proposal_number=proposal_number,
294
- per_driver_effort=proposal_to_check["per_driver_effort"],
295
  )
296
 
297
- # Check fairness (EXISTING CODE - UNCHANGED)
298
- fairness_result = fairness_agent.check(plan_result, proposal_number=proposal_number)
299
 
300
- # Serialize result
301
  fairness_dict = {
302
- "status": fairness_result.status,
303
- "proposal_number": fairness_result.proposal_number,
304
  "metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
305
- "avg_effort": fairness_result.metrics.avg_effort,
306
- "std_dev": fairness_result.metrics.std_dev,
307
- "gini_index": fairness_result.metrics.gini_index,
308
- "max_effort": fairness_result.metrics.max_effort,
309
- "min_effort": fairness_result.metrics.min_effort,
310
- "max_gap": fairness_result.metrics.max_gap,
311
  },
312
  "recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
313
  }
314
 
315
- # Create decision log
316
  log_entry = _create_decision_log(
317
- agent_name="FAIRNESS_MANAGER",
318
- step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
319
  input_snapshot=fairness_agent.get_input_snapshot(plan_result),
320
  output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
321
  )
322
 
323
- # Publish COMPLETED event
324
- _publish_event_sync(run_id, "FAIRNESS_MANAGER", f"FAIRNESS_CHECK_{proposal_number}", "COMPLETED", {
325
- "status": fairness_result.status,
326
- "gini_index": fairness_dict["metrics"]["gini_index"],
327
- "std_dev": fairness_dict["metrics"]["std_dev"],
328
  })
329
 
330
- # Update appropriate check result based on proposal number
331
- updates = {
332
- "decision_logs": state.decision_logs + [log_entry],
333
- }
334
-
335
- if proposal_number == 1:
336
- updates["fairness_check_1"] = fairness_dict
337
- else:
338
- updates["fairness_check_2"] = fairness_dict
339
-
340
- return updates
341
 
342
 
343
  def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
344
- """
345
- LangGraph node for second fairness check (dedicated wrapper).
346
-
347
- This is a separate function to avoid LangGraph state key conflicts
348
- when the same function is used for multiple nodes.
349
- """
350
  run_id = state.allocation_run_id
351
- proposal_number = 2 # Always proposal 2 for this node
352
 
353
- # Publish STARTED event
354
- _publish_event_sync(run_id, "FAIRNESS_MANAGER", f"FAIRNESS_CHECK_{proposal_number}", "STARTED", {
355
- "proposal_number": proposal_number,
356
- })
357
 
358
- # Get thresholds from config
359
  thresholds = FairnessThresholds(
360
  gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
361
  stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
@@ -363,164 +274,98 @@ def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
363
  )
364
 
365
  fairness_agent = FairnessManagerAgent(thresholds=thresholds)
366
-
367
- # Reconstruct RoutePlanResult for fairness check
368
  from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
369
 
370
- # Always check proposal 2 for this node
371
  proposal_to_check = state.route_proposal_2
372
-
373
  plan_result = RoutePlanResult(
374
  allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
375
  total_effort=proposal_to_check["total_effort"],
376
- avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / len(proposal_to_check["allocation"]) if proposal_to_check["allocation"] else 0.0),
377
  solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
378
- proposal_number=proposal_number,
379
- per_driver_effort=proposal_to_check["per_driver_effort"],
380
  )
381
 
382
- # Check fairness
383
- fairness_result = fairness_agent.check(plan_result, proposal_number=proposal_number)
384
 
385
- # Serialize result
386
  fairness_dict = {
387
- "status": fairness_result.status,
388
- "proposal_number": fairness_result.proposal_number,
389
  "metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
390
- "avg_effort": fairness_result.metrics.avg_effort,
391
- "std_dev": fairness_result.metrics.std_dev,
392
- "gini_index": fairness_result.metrics.gini_index,
393
- "max_effort": fairness_result.metrics.max_effort,
394
- "min_effort": fairness_result.metrics.min_effort,
395
- "max_gap": fairness_result.metrics.max_gap,
396
  },
397
  "recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
398
  }
399
 
400
- # Create decision log
401
  log_entry = _create_decision_log(
402
- agent_name="FAIRNESS_MANAGER",
403
- step_type=f"FAIRNESS_CHECK_PROPOSAL_{proposal_number}",
404
  input_snapshot=fairness_agent.get_input_snapshot(plan_result),
405
  output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
406
  )
407
 
408
- # Publish COMPLETED event
409
- _publish_event_sync(run_id, "FAIRNESS_MANAGER", f"FAIRNESS_CHECK_{proposal_number}", "COMPLETED", {
410
- "status": fairness_result.status,
411
- "gini_index": fairness_dict["metrics"]["gini_index"],
412
- "std_dev": fairness_dict["metrics"]["std_dev"],
413
  })
414
 
415
- return {
416
- "fairness_check_2": fairness_dict,
417
- "decision_logs": state.decision_logs + [log_entry],
418
- }
419
 
420
 
421
  # =============================================================================
422
- # Node 4: Route Planner Agent (Proposal 2 - with fairness penalties)
423
  # =============================================================================
424
 
425
-
426
  def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
427
- """
428
- LangGraph node #4: Route Planner Agent - Proposal 2 (re-optimization).
429
-
430
- Re-runs OR-Tools with fairness penalties applied.
431
- WRAPS EXISTING AGENT - no logic changes.
432
- """
433
  run_id = state.allocation_run_id
434
 
435
- # Publish STARTED event
436
- _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "STARTED", {
437
- "reason": "fairness_reoptimization",
438
- })
439
 
440
  planner_agent = RoutePlannerAgent()
441
-
442
- # Reconstruct effort result
443
  from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations
444
 
445
- # Use stats from serialized state or compute if not available
446
  matrix = state.effort_matrix["matrix"]
447
- stats = state.effort_matrix.get("stats")
448
- if not stats:
449
- all_values = [v for row in matrix for v in row if v < float('inf')]
450
- stats = {
451
- "min": min(all_values) if all_values else 0.0,
452
- "max": max(all_values) if all_values else 0.0,
453
- "avg": sum(all_values) / len(all_values) if all_values else 0.0,
454
- }
455
 
456
  effort_result = EffortMatrixResult(
457
- matrix=matrix,
458
- driver_ids=state.effort_matrix["driver_ids"],
459
- route_ids=state.effort_matrix["route_ids"],
460
- breakdown={},
461
- stats=stats,
462
  infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
463
  )
464
 
465
- # Build penalties from recommendations
466
  recommendations_dict = state.fairness_check_1.get("recommendations")
467
  penalties = {}
468
-
469
  if recommendations_dict:
470
  recommendations = FairnessRecommendations(**recommendations_dict)
471
- penalties = planner_agent.build_penalties_from_recommendations(
472
- recommendations,
473
- state.route_proposal_1["per_driver_effort"],
474
- )
475
 
476
- # Get recovery settings
477
  recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
478
-
479
- # Wrap dicts as objects for agent compatibility
480
  drivers = [ModelWrapper(d) for d in state.driver_models]
481
  routes = [ModelWrapper(r) for r in state.route_models]
482
 
483
- # Generate Proposal 2 (EXISTING CODE - UNCHANGED)
484
  proposal2 = planner_agent.plan(
485
- effort_result=effort_result,
486
- drivers=drivers,
487
- routes=routes,
488
- fairness_penalties=penalties,
489
- recovery_targets=state.recovery_targets or {},
490
- recovery_penalty_weight=recovery_penalty_weight,
491
- proposal_number=2,
492
  )
493
 
494
- # Serialize result
495
  proposal_dict = {
496
  "allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal2.allocation],
497
- "total_effort": proposal2.total_effort,
498
- "avg_effort": proposal2.avg_effort,
499
- "solver_status": proposal2.solver_status,
500
- "proposal_number": proposal2.proposal_number,
501
  "per_driver_effort": proposal2.per_driver_effort,
502
  }
503
 
504
- # Create decision log
505
  log_entry = _create_decision_log(
506
- agent_name="ROUTE_PLANNER",
507
- step_type="PROPOSAL_2",
508
  input_snapshot=planner_agent.get_input_snapshot(effort_result, penalties),
509
  output_snapshot=planner_agent.get_output_snapshot(proposal2),
510
  )
511
 
512
- # Publish COMPLETED event
513
  _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "COMPLETED", {
514
- "total_effort": proposal2.total_effort,
515
- "num_assignments": len(proposal2.allocation),
516
- "solver_status": proposal2.solver_status,
517
  })
518
 
519
- return {
520
- "route_proposal_2": proposal_dict,
521
- "decision_logs": state.decision_logs + [log_entry],
522
- }
523
-
524
 
525
 
526
  # =============================================================================
@@ -528,31 +373,19 @@ def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
528
  # =============================================================================
529
 
530
  def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
531
- """
532
- Select the final proposal after fairness checks.
533
-
534
- If proposal 2 exists and has better fairness, use it.
535
- Otherwise, use proposal 1.
536
- """
537
  final_proposal = state.route_proposal_1
538
  final_fairness = state.fairness_check_1
539
 
540
  if state.route_proposal_2 and state.fairness_check_2:
541
- # Compare fairness metrics
542
  check1_metrics = state.fairness_check_1["metrics"]
543
  check2_metrics = state.fairness_check_2["metrics"]
544
-
545
- # Use proposal 2 if it improves fairness
546
  if (check2_metrics["gini_index"] <= check1_metrics["gini_index"] or
547
  check2_metrics["max_gap"] < check1_metrics["max_gap"]):
548
  final_proposal = state.route_proposal_2
549
  final_fairness = state.fairness_check_2
550
 
551
- return {
552
- "final_proposal": final_proposal,
553
- "final_fairness": final_fairness,
554
- "final_per_driver_effort": final_proposal["per_driver_effort"],
555
- }
556
 
557
 
558
  # =============================================================================
@@ -560,64 +393,37 @@ def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
560
  # =============================================================================
561
 
562
  def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
563
- """
564
- LangGraph node #6: Driver Liaison Agent.
565
-
566
- Reviews proposed assignments and makes ACCEPT/COUNTER decisions per driver.
567
- WRAPS EXISTING AGENT - no logic changes.
568
- """
569
  run_id = state.allocation_run_id
570
 
571
- # Publish STARTED event
572
- _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "STARTED", {
573
- "num_drivers": len(state.driver_models),
574
- })
575
 
576
  from app.schemas.agent_schemas import AllocationItem
577
-
578
  liaison_agent = DriverLiaisonAgent()
579
 
580
  final_proposal = state.final_proposal or state.route_proposal_1
581
  final_fairness = state.final_fairness or state.fairness_check_1
582
 
583
- # Build DriverAssignmentProposals with ranking
584
- sorted_allocations = sorted(
585
- final_proposal["allocation"],
586
- key=lambda x: x["effort"],
587
- reverse=True # Highest effort = rank 1
588
- )
589
-
590
  driver_proposals: List[DriverAssignmentProposal] = []
591
  for rank, alloc_item in enumerate(sorted_allocations, start=1):
592
  driver_proposals.append(DriverAssignmentProposal(
593
- driver_id=str(alloc_item["driver_id"]),
594
- route_id=str(alloc_item["route_id"]),
595
- effort=alloc_item["effort"],
596
- rank_in_team=rank,
597
  ))
598
 
599
- # Get global metrics
600
  metrics = final_fairness["metrics"]
601
- global_avg_effort = metrics["avg_effort"]
602
- global_std_effort = metrics["std_dev"]
603
-
604
- # Build DriverContext objects
605
  driver_context_objs: Dict[str, DriverContext] = {}
606
- for driver_id, context_dict in state.driver_contexts.items():
607
  driver_context_objs[driver_id] = DriverContext(**context_dict)
608
 
609
- # Run liaison for all drivers (EXISTING CODE - UNCHANGED)
610
  negotiation_result = liaison_agent.run_for_all_drivers(
611
- proposals=driver_proposals,
612
- driver_contexts=driver_context_objs,
613
- effort_matrix=state.effort_matrix["matrix"],
614
- driver_ids=state.effort_matrix["driver_ids"],
615
  route_ids=state.effort_matrix["route_ids"],
616
- global_avg_effort=global_avg_effort,
617
- global_std_effort=global_std_effort,
618
  )
619
 
620
- # Serialize result
621
  liaison_dict = {
622
  "decisions": [d.model_dump() if hasattr(d, 'model_dump') else d for d in negotiation_result.decisions],
623
  "num_accept": negotiation_result.num_accept,
@@ -625,29 +431,17 @@ def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
625
  "num_force_accept": negotiation_result.num_force_accept,
626
  }
627
 
628
- # Create decision log
629
  log_entry = _create_decision_log(
630
- agent_name="DRIVER_LIAISON",
631
- step_type="NEGOTIATION_DECISIONS",
632
- input_snapshot=liaison_agent.get_input_snapshot(
633
- driver_proposals,
634
- global_avg_effort,
635
- global_std_effort,
636
- ),
637
  output_snapshot=liaison_agent.get_output_snapshot(negotiation_result),
638
  )
639
 
640
- # Publish COMPLETED event
641
  _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "COMPLETED", {
642
- "num_accept": negotiation_result.num_accept,
643
- "num_counter": negotiation_result.num_counter,
644
- "num_force_accept": negotiation_result.num_force_accept,
645
  })
646
 
647
- return {
648
- "liaison_feedback": liaison_dict,
649
- "decision_logs": state.decision_logs + [log_entry],
650
- }
651
 
652
 
653
  # =============================================================================
@@ -655,101 +449,61 @@ def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
655
  # =============================================================================
656
 
657
  def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
658
- """
659
- LangGraph node #7: Final Resolution Agent.
660
-
661
- Resolves COUNTER decisions through swaps.
662
- WRAPS EXISTING AGENT - no logic changes.
663
- """
664
  run_id = state.allocation_run_id
665
  from app.schemas.agent_schemas import RoutePlanResult, AllocationItem, FairnessMetrics, DriverLiaisonDecision
666
 
667
- # Check if there are any COUNTER decisions
668
- counter_decisions = [
669
- d for d in state.liaison_feedback["decisions"]
670
- if d["decision"] == "COUNTER"
671
- ]
672
 
673
  if not counter_decisions:
674
- # Publish SKIPPED event
675
- _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {
676
- "reason": "no_counters",
677
- "swaps_applied": 0,
678
- })
679
- # No resolution needed
680
- return {
681
- "resolution_result": {"swaps_applied": []},
682
- }
683
 
684
- # Publish STARTED event
685
- _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {
686
- "num_counters": len(counter_decisions),
687
- })
688
 
689
  resolution_agent = FinalResolutionAgent()
690
-
691
- # Reconstruct objects for resolution
692
  final_proposal = state.final_proposal or state.route_proposal_1
693
  final_fairness = state.final_fairness or state.fairness_check_1
694
 
695
  approved_proposal = RoutePlanResult(
696
  allocation=[AllocationItem(**a) for a in final_proposal["allocation"]],
697
  total_effort=final_proposal["total_effort"],
698
- avg_effort=final_proposal.get("avg_effort", final_proposal["total_effort"] / len(final_proposal["allocation"]) if final_proposal["allocation"] else 0.0),
699
  solver_status=final_proposal.get("solver_status", "OPTIMAL"),
700
  proposal_number=final_proposal["proposal_number"],
701
  per_driver_effort=final_proposal["per_driver_effort"],
702
  )
703
 
704
  decisions = [DriverLiaisonDecision(**d) for d in state.liaison_feedback["decisions"]]
705
-
706
  current_metrics = FairnessMetrics(**final_fairness["metrics"])
707
 
708
- # Resolve counters (EXISTING CODE - UNCHANGED)
709
  resolution_result = resolution_agent.resolve_counters(
710
- approved_proposal=approved_proposal,
711
- decisions=decisions,
712
  effort_matrix=state.effort_matrix["matrix"],
713
- driver_ids=state.effort_matrix["driver_ids"],
714
- route_ids=state.effort_matrix["route_ids"],
715
  current_metrics=current_metrics,
716
  )
717
 
718
- # Serialize result
719
  resolution_dict = {
720
  "swaps_applied": [s.model_dump() if hasattr(s, 'model_dump') else s for s in resolution_result.swaps_applied],
721
- "allocation": resolution_result.allocation, # Already a list of dicts
722
  "per_driver_effort": resolution_result.per_driver_effort,
723
  "metrics": resolution_result.metrics,
724
  }
725
 
726
- # Create decision log
727
  log_entry = _create_decision_log(
728
- agent_name="ROUTE_PLANNER",
729
- step_type="FINAL_RESOLUTION",
730
- input_snapshot=resolution_agent.get_input_snapshot(
731
- len(counter_decisions),
732
- current_metrics,
733
- final_fairness["metrics"]["avg_effort"],
734
- ),
735
  output_snapshot=resolution_agent.get_output_snapshot(resolution_result),
736
  )
737
 
738
- # Publish COMPLETED event
739
- _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {
740
- "swaps_applied": len(resolution_result.swaps_applied),
741
- })
742
 
743
- # Update per-driver effort if swaps were applied
744
- updated_effort = state.final_per_driver_effort.copy()
745
  if resolution_result.swaps_applied:
746
  updated_effort = resolution_result.per_driver_effort
747
 
748
- return {
749
- "resolution_result": resolution_dict,
750
- "final_per_driver_effort": updated_effort,
751
- "decision_logs": state.decision_logs + [log_entry],
752
- }
753
 
754
 
755
  # =============================================================================
@@ -757,21 +511,12 @@ def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
757
  # =============================================================================
758
 
759
  def explainability_node(state: AllocationState) -> Dict[str, Any]:
760
- """
761
- LangGraph node #8: Explainability Agent.
762
-
763
- Generates template-based explanations for each driver.
764
- WRAPS EXISTING AGENT - no logic changes.
765
- """
766
  run_id = state.allocation_run_id
767
 
768
- # Publish STARTED event
769
- _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {
770
- "num_drivers": len(state.driver_models),
771
- })
772
 
773
  explain_agent = ExplainabilityAgent()
774
-
775
  final_proposal = state.final_proposal or state.route_proposal_1
776
  final_fairness = state.final_fairness or state.fairness_check_1
777
  final_per_driver_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"]
@@ -779,32 +524,24 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
779
  metrics = final_fairness["metrics"]
780
  avg_effort = metrics["avg_effort"]
781
 
782
- # Build lookup structures
783
  route_by_id = {str(r["id"]): r for r in state.route_models}
784
  driver_by_id = {str(d["id"]): d for d in state.driver_models}
785
- route_dict_by_id = {str(r["id"]): rd for r, rd in zip(state.route_models, state.route_dicts)}
786
 
787
- # Compute per-driver ranks
788
- sorted_efforts = sorted(
789
- final_per_driver_effort.items(),
790
- key=lambda x: x[1],
791
- reverse=True
792
- )
793
  rank_by_driver = {did: idx + 1 for idx, (did, _) in enumerate(sorted_efforts)}
794
  num_drivers = len(final_per_driver_effort)
795
 
796
- # Build liaison decisions lookup
797
  liaison_by_driver = {}
798
  if state.liaison_feedback:
799
  for decision in state.liaison_feedback["decisions"]:
800
  liaison_by_driver[decision["driver_id"]] = decision
801
 
802
- # Build swaps lookup
803
  swapped_drivers = set()
804
  if state.resolution_result and state.resolution_result.get("swaps_applied"):
805
  for swap in state.resolution_result["swaps_applied"]:
806
- swapped_drivers.add(swap["driver_a"])
807
- swapped_drivers.add(swap["driver_b"])
808
 
809
  explanations: Dict[str, Dict[str, Any]] = {}
810
  category_counts: Dict[str, int] = {}
@@ -815,18 +552,13 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
815
 
816
  driver = driver_by_id.get(driver_id_str, {})
817
  route = route_by_id.get(route_id_str, {})
818
- route_dict = route_dict_by_id.get(route_id_str, {})
819
 
820
- # Use resolved effort if available
821
  effort = final_per_driver_effort.get(driver_id_str, alloc_item["effort"])
822
  fairness_score = calculate_fairness_score(effort, avg_effort)
823
-
824
- # Get driver context
825
- driver_context = state.driver_contexts.get(driver_id_str, {})
826
  history_efforts = [driver_context.get("recent_avg_effort", avg_effort)] if driver_context else []
827
  history_hard_days = driver_context.get("recent_hard_days", 0) if driver_context else 0
828
 
829
- # Get effort breakdown
830
  breakdown_key = f"{driver_id_str}:{route_id_str}"
831
  effort_breakdown_data = state.effort_matrix.get("breakdown", {}).get(breakdown_key, {})
832
  effort_breakdown = {
@@ -835,82 +567,42 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
835
  "time_pressure": effort_breakdown_data.get("time_pressure", 0),
836
  }
837
 
838
- # Get liaison decision
839
  liaison_decision = liaison_by_driver.get(driver_id_str)
 
840
 
841
- # Determine if recovery day
842
- is_recovery = (
843
- history_hard_days >= 3 and
844
- effort < avg_effort * 0.85
845
- )
846
-
847
- # Build explanation input
848
  explain_input = DriverExplanationInput(
849
- driver_id=driver_id_str,
850
- driver_name=driver.get("name", "Driver"),
851
- num_drivers=num_drivers,
852
- today_effort=effort,
853
  today_rank=rank_by_driver.get(driver_id_str, num_drivers),
854
  route_id=route_id_str,
855
- route_summary={
856
- "num_packages": route.get("num_packages", 0),
857
- "total_weight_kg": route.get("total_weight_kg", 0),
858
- "num_stops": route.get("num_stops", 0),
859
- "difficulty_score": route.get("route_difficulty_score", 0),
860
- "estimated_time_minutes": route.get("estimated_time_minutes", 0),
861
- },
862
  effort_breakdown=effort_breakdown,
863
- global_avg_effort=avg_effort,
864
- global_std_effort=metrics["std_dev"],
865
- global_gini_index=metrics["gini_index"],
866
- global_max_gap=metrics["max_gap"],
867
  history_efforts_last_7_days=history_efforts,
868
- history_hard_days_last_7=history_hard_days,
869
- is_recovery_day=is_recovery,
870
- had_manual_override=False, # TODO: Query DB if needed
871
  liaison_decision=liaison_decision["decision"] if liaison_decision else None,
872
  swap_applied=driver_id_str in swapped_drivers,
873
  )
874
 
875
- # Generate explanations (EXISTING CODE - UNCHANGED)
876
  explain_output = explain_agent.build_explanation_for_driver(explain_input)
877
-
878
- # Track category counts
879
  category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1
880
-
881
  explanations[driver_id_str] = {
882
  "driver_explanation": explain_output.driver_explanation,
883
  "admin_explanation": explain_output.admin_explanation,
884
  "category": explain_output.category,
885
  }
886
 
887
- # Create decision log
888
  log_entry = _create_decision_log(
889
- agent_name="EXPLAINABILITY",
890
- step_type="EXPLANATIONS_GENERATED",
891
- input_snapshot=explain_agent.get_input_snapshot(
892
- num_drivers=num_drivers,
893
- avg_effort=avg_effort,
894
- std_effort=metrics["std_dev"],
895
- gini_index=metrics["gini_index"],
896
- category_counts=category_counts,
897
- ),
898
- output_snapshot=explain_agent.get_output_snapshot(
899
- total_explanations=len(explanations),
900
- category_counts=category_counts,
901
- ),
902
  )
903
 
904
- # Publish COMPLETED event
905
- _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED", {
906
- "total_explanations": len(explanations),
907
- "categories": category_counts,
908
- })
909
 
910
- return {
911
- "explanations": explanations,
912
- "decision_logs": state.decision_logs + [log_entry],
913
- }
914
 
915
 
916
  # =============================================================================
@@ -918,13 +610,7 @@ def explainability_node(state: AllocationState) -> Dict[str, Any]:
918
  # =============================================================================
919
 
920
  def should_reoptimize(state: AllocationState) -> str:
921
- """
922
- Conditional edge: decide if re-optimization is needed.
923
-
924
- Returns:
925
- "reoptimize" - if fairness check 1 says REOPTIMIZE and no proposal 2 yet
926
- "continue" - otherwise
927
- """
928
  if state.fairness_check_1 and state.fairness_check_1.get("status") == "REOPTIMIZE":
929
  if not state.route_proposal_2:
930
  return "reoptimize"
@@ -932,18 +618,8 @@ def should_reoptimize(state: AllocationState) -> str:
932
 
933
 
934
  def has_counter_decisions(state: AllocationState) -> str:
935
- """
936
- Conditional edge: check if any COUNTER decisions need resolution.
937
-
938
- Returns:
939
- "resolve" - if there are COUNTER decisions
940
- "skip" - otherwise
941
- """
942
  if state.liaison_feedback:
943
- counter_count = sum(
944
- 1 for d in state.liaison_feedback["decisions"]
945
- if d["decision"] == "COUNTER"
946
- )
947
- if counter_count > 0:
948
  return "resolve"
949
  return "skip"
 
1
  """
2
  LangGraph node wrappers for Fair Dispatch agents.
3
  Each node wraps an existing agent with minimal changes, preserving the original logic.
4
+
5
+ PRODUCTION FIXES APPLIED:
6
+ - ModelWrapper.is_ev: handles "EV", "ELECTRIC", VehicleType.EV enum values
7
+ - _publish_event_sync: improved reliability with asyncio.ensure_future
8
  """
9
 
10
  from datetime import datetime
11
  from typing import Dict, Any, List, Optional, Tuple
12
  from uuid import UUID
13
  import asyncio
14
+ import logging
15
 
16
  from app.schemas.allocation_state import AllocationState
17
  from app.schemas.agent_schemas import (
 
29
  from app.services.fairness import calculate_fairness_score
30
  from app.core.events import agent_event_bus, make_agent_event
31
 
32
+ logger = logging.getLogger("fairrelay.langgraph")
33
+
34
 
35
  class ModelWrapper:
36
  """Helper to wrap dicts as objects for agent compatibility."""
 
42
 
43
  @property
44
  def is_ev(self) -> bool:
45
+ """Check if driver has an EV - handles all possible enum/string formats."""
46
+ vt = self._data.get("vehicle_type", "")
47
+ vt_str = str(vt).upper()
48
+ return vt_str in ("EV", "ELECTRIC", "VEHICLETYPE.EV")
49
 
50
 
51
  def _publish_event_sync(
 
58
  """
59
  Publish an agent event synchronously (fire-and-forget).
60
  Used by LangGraph nodes which are synchronous functions.
61
+
62
+ Uses asyncio.ensure_future for reliable delivery when a loop is running.
63
  """
64
  if not allocation_run_id:
65
  return
 
72
  payload=payload,
73
  )
74
 
75
+ # Schedule async publish on the running event loop
76
  try:
77
  loop = asyncio.get_running_loop()
78
+ asyncio.ensure_future(agent_event_bus.publish(event), loop=loop)
79
  except RuntimeError:
80
+ # No running loop this shouldn't happen in FastAPI context
81
+ # but handle gracefully for testing
82
+ try:
83
+ asyncio.run(agent_event_bus.publish(event))
84
+ except Exception as e:
85
+ logger.warning(f"Failed to publish agent event: {e}")
86
 
87
 
88
  def _create_decision_log(
 
108
  def ml_effort_node(state: AllocationState) -> Dict[str, Any]:
109
  """
110
  LangGraph node #1: ML Effort Agent.
 
111
  Computes effort matrix for all driver-route pairs using MLEffortAgent.
 
112
  """
113
  run_id = state.allocation_run_id
114
 
 
115
  _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "STARTED", {
116
  "num_drivers": len(state.driver_models),
117
  "num_routes": len(state.route_models),
118
  })
119
 
 
120
  ml_agent = MLEffortAgent()
121
 
 
122
  ev_config = {
123
  "safety_margin_pct": state.config_used.get("ev_safety_margin_pct", 10.0) if state.config_used else 10.0,
124
  "charging_penalty_weight": state.config_used.get("ev_charging_penalty_weight", 0.3) if state.config_used else 0.3,
125
  }
126
 
 
127
  drivers = [ModelWrapper(d) for d in state.driver_models]
128
  routes = [ModelWrapper(r) for r in state.route_models]
129
 
130
+ effort_result = ml_agent.compute_effort_matrix(drivers=drivers, routes=routes, ev_config=ev_config)
 
 
 
 
 
131
 
 
132
  effort_dict = {
133
  "matrix": effort_result.matrix,
134
  "driver_ids": effort_result.driver_ids,
135
  "route_ids": effort_result.route_ids,
136
+ "breakdown": {k: v.model_dump() if hasattr(v, 'model_dump') else v for k, v in effort_result.breakdown.items()},
 
137
  "stats": effort_result.stats,
138
  "infeasible_pairs": list(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else [],
139
  }
140
 
 
141
  log_entry = _create_decision_log(
142
+ agent_name="ML_EFFORT", step_type="MATRIX_GENERATION",
 
143
  input_snapshot=ml_agent.get_input_snapshot(drivers, routes),
144
+ output_snapshot={**ml_agent.get_output_snapshot(effort_result), "num_infeasible_ev_pairs": len(effort_result.infeasible_pairs) if effort_result.infeasible_pairs else 0},
 
 
 
145
  )
146
 
 
147
  _publish_event_sync(run_id, "ML_EFFORT", "MATRIX_GENERATION", "COMPLETED", {
148
  "min_effort": effort_result.stats.get("min", 0),
149
  "max_effort": effort_result.stats.get("max", 0),
150
  "avg_effort": effort_result.stats.get("avg", 0),
151
  })
152
 
153
+ return {"effort_matrix": effort_dict, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
154
 
155
 
156
  # =============================================================================
 
158
  # =============================================================================
159
 
160
  def route_planner_node(state: AllocationState) -> Dict[str, Any]:
161
+ """LangGraph node #2: Route Planner Agent - Proposal 1 (OR-Tools optimization)."""
 
 
 
 
 
162
  run_id = state.allocation_run_id
163
 
 
164
  _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "STARTED", {
165
+ "num_drivers": len(state.driver_models), "num_routes": len(state.route_models),
 
166
  })
167
 
168
  planner_agent = RoutePlannerAgent()
169
+ from app.schemas.agent_schemas import EffortMatrixResult
170
 
 
 
 
 
171
  matrix = state.effort_matrix["matrix"]
172
+ stats = state.effort_matrix.get("stats") or {"min": 0, "max": 0, "avg": 0}
 
 
 
 
 
 
 
173
 
174
  effort_result = EffortMatrixResult(
175
+ matrix=matrix, driver_ids=state.effort_matrix["driver_ids"],
176
+ route_ids=state.effort_matrix["route_ids"], breakdown={}, stats=stats,
 
 
 
177
  infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
178
  )
179
 
 
180
  recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
 
 
181
  drivers = [ModelWrapper(d) for d in state.driver_models]
182
  routes = [ModelWrapper(r) for r in state.route_models]
183
 
 
184
  proposal1 = planner_agent.plan(
185
+ effort_result=effort_result, drivers=drivers, routes=routes,
 
 
186
  recovery_targets=state.recovery_targets or {},
187
+ recovery_penalty_weight=recovery_penalty_weight, proposal_number=1,
 
188
  )
189
 
 
190
  proposal_dict = {
191
  "allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal1.allocation],
192
+ "total_effort": proposal1.total_effort, "avg_effort": proposal1.avg_effort,
193
+ "solver_status": proposal1.solver_status, "proposal_number": proposal1.proposal_number,
 
 
194
  "per_driver_effort": proposal1.per_driver_effort,
195
  }
196
 
 
197
  log_entry = _create_decision_log(
198
+ agent_name="ROUTE_PLANNER", step_type="PROPOSAL_1",
 
199
  input_snapshot=planner_agent.get_input_snapshot(effort_result),
200
  output_snapshot=planner_agent.get_output_snapshot(proposal1),
201
  )
202
 
 
203
  _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_1", "COMPLETED", {
204
+ "total_effort": proposal1.total_effort, "num_assignments": len(proposal1.allocation),
 
205
  "solver_status": proposal1.solver_status,
206
  })
207
 
208
+ return {"route_proposal_1": proposal_dict, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
209
 
210
 
211
  # =============================================================================
 
213
  # =============================================================================
214
 
215
  def fairness_check_node(state: AllocationState) -> Dict[str, Any]:
216
+ """LangGraph node #3: Fairness Manager Agent — evaluates Gini/stddev/max_gap."""
 
 
 
 
 
217
  run_id = state.allocation_run_id
 
218
 
219
+ _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_1", "STARTED", {"proposal_number": 1})
 
 
 
220
 
 
221
  thresholds = FairnessThresholds(
222
  gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
223
  stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
 
225
  )
226
 
227
  fairness_agent = FairnessManagerAgent(thresholds=thresholds)
 
 
228
  from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
229
 
230
+ proposal_to_check = state.route_proposal_1
 
 
231
  plan_result = RoutePlanResult(
232
  allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
233
  total_effort=proposal_to_check["total_effort"],
234
+ avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / max(len(proposal_to_check["allocation"]), 1)),
235
  solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
236
+ proposal_number=1, per_driver_effort=proposal_to_check["per_driver_effort"],
 
237
  )
238
 
239
+ fairness_result = fairness_agent.check(plan_result, proposal_number=1)
 
240
 
 
241
  fairness_dict = {
242
+ "status": fairness_result.status, "proposal_number": fairness_result.proposal_number,
 
243
  "metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
244
+ "avg_effort": fairness_result.metrics.avg_effort, "std_dev": fairness_result.metrics.std_dev,
245
+ "gini_index": fairness_result.metrics.gini_index, "max_effort": fairness_result.metrics.max_effort,
246
+ "min_effort": fairness_result.metrics.min_effort, "max_gap": fairness_result.metrics.max_gap,
 
 
 
247
  },
248
  "recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
249
  }
250
 
 
251
  log_entry = _create_decision_log(
252
+ agent_name="FAIRNESS_MANAGER", step_type="FAIRNESS_CHECK_PROPOSAL_1",
 
253
  input_snapshot=fairness_agent.get_input_snapshot(plan_result),
254
  output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
255
  )
256
 
257
+ _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_1", "COMPLETED", {
258
+ "status": fairness_result.status, "gini_index": fairness_dict["metrics"]["gini_index"],
 
 
 
259
  })
260
 
261
+ return {"fairness_check_1": fairness_dict, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
 
 
 
 
 
 
 
262
 
263
 
264
  def fairness_check_2_node(state: AllocationState) -> Dict[str, Any]:
265
+ """LangGraph node for second fairness check (after re-optimization)."""
 
 
 
 
 
266
  run_id = state.allocation_run_id
 
267
 
268
+ _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_2", "STARTED", {"proposal_number": 2})
 
 
 
269
 
 
270
  thresholds = FairnessThresholds(
271
  gini_threshold=state.config_used.get("gini_threshold", 0.33) if state.config_used else 0.33,
272
  stddev_threshold=state.config_used.get("stddev_threshold", 25.0) if state.config_used else 25.0,
 
274
  )
275
 
276
  fairness_agent = FairnessManagerAgent(thresholds=thresholds)
 
 
277
  from app.schemas.agent_schemas import RoutePlanResult, AllocationItem
278
 
 
279
  proposal_to_check = state.route_proposal_2
 
280
  plan_result = RoutePlanResult(
281
  allocation=[AllocationItem(**a) for a in proposal_to_check["allocation"]],
282
  total_effort=proposal_to_check["total_effort"],
283
+ avg_effort=proposal_to_check.get("avg_effort", proposal_to_check["total_effort"] / max(len(proposal_to_check["allocation"]), 1)),
284
  solver_status=proposal_to_check.get("solver_status", "OPTIMAL"),
285
+ proposal_number=2, per_driver_effort=proposal_to_check["per_driver_effort"],
 
286
  )
287
 
288
+ fairness_result = fairness_agent.check(plan_result, proposal_number=2)
 
289
 
 
290
  fairness_dict = {
291
+ "status": fairness_result.status, "proposal_number": 2,
 
292
  "metrics": fairness_result.metrics.model_dump() if hasattr(fairness_result.metrics, 'model_dump') else {
293
+ "avg_effort": fairness_result.metrics.avg_effort, "std_dev": fairness_result.metrics.std_dev,
294
+ "gini_index": fairness_result.metrics.gini_index, "max_effort": fairness_result.metrics.max_effort,
295
+ "min_effort": fairness_result.metrics.min_effort, "max_gap": fairness_result.metrics.max_gap,
 
 
 
296
  },
297
  "recommendations": fairness_result.recommendations.model_dump() if fairness_result.recommendations and hasattr(fairness_result.recommendations, 'model_dump') else None,
298
  }
299
 
 
300
  log_entry = _create_decision_log(
301
+ agent_name="FAIRNESS_MANAGER", step_type="FAIRNESS_CHECK_PROPOSAL_2",
 
302
  input_snapshot=fairness_agent.get_input_snapshot(plan_result),
303
  output_snapshot=fairness_agent.get_output_snapshot(fairness_result),
304
  )
305
 
306
+ _publish_event_sync(run_id, "FAIRNESS_MANAGER", "FAIRNESS_CHECK_2", "COMPLETED", {
307
+ "status": fairness_result.status, "gini_index": fairness_dict["metrics"]["gini_index"],
 
 
 
308
  })
309
 
310
+ return {"fairness_check_2": fairness_dict, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
311
 
312
 
313
  # =============================================================================
314
+ # Node 4: Route Planner Re-optimization (Proposal 2)
315
  # =============================================================================
316
 
 
317
  def route_planner_reoptimize_node(state: AllocationState) -> Dict[str, Any]:
318
+ """LangGraph node #4: Route Planner - Proposal 2 with fairness penalties."""
 
 
 
 
 
319
  run_id = state.allocation_run_id
320
 
321
+ _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "STARTED", {"reason": "fairness_reoptimization"})
 
 
 
322
 
323
  planner_agent = RoutePlannerAgent()
 
 
324
  from app.schemas.agent_schemas import EffortMatrixResult, FairnessRecommendations
325
 
 
326
  matrix = state.effort_matrix["matrix"]
327
+ stats = state.effort_matrix.get("stats") or {"min": 0, "max": 0, "avg": 0}
 
 
 
 
 
 
 
328
 
329
  effort_result = EffortMatrixResult(
330
+ matrix=matrix, driver_ids=state.effort_matrix["driver_ids"],
331
+ route_ids=state.effort_matrix["route_ids"], breakdown={}, stats=stats,
 
 
 
332
  infeasible_pairs=list(state.effort_matrix.get("infeasible_pairs", [])),
333
  )
334
 
 
335
  recommendations_dict = state.fairness_check_1.get("recommendations")
336
  penalties = {}
 
337
  if recommendations_dict:
338
  recommendations = FairnessRecommendations(**recommendations_dict)
339
+ penalties = planner_agent.build_penalties_from_recommendations(recommendations, state.route_proposal_1["per_driver_effort"])
 
 
 
340
 
 
341
  recovery_penalty_weight = state.config_used.get("recovery_penalty_weight", 3.0) if state.config_used else 3.0
 
 
342
  drivers = [ModelWrapper(d) for d in state.driver_models]
343
  routes = [ModelWrapper(r) for r in state.route_models]
344
 
 
345
  proposal2 = planner_agent.plan(
346
+ effort_result=effort_result, drivers=drivers, routes=routes,
347
+ fairness_penalties=penalties, recovery_targets=state.recovery_targets or {},
348
+ recovery_penalty_weight=recovery_penalty_weight, proposal_number=2,
 
 
 
 
349
  )
350
 
 
351
  proposal_dict = {
352
  "allocation": [a.model_dump() if hasattr(a, 'model_dump') else a for a in proposal2.allocation],
353
+ "total_effort": proposal2.total_effort, "avg_effort": proposal2.avg_effort,
354
+ "solver_status": proposal2.solver_status, "proposal_number": 2,
 
 
355
  "per_driver_effort": proposal2.per_driver_effort,
356
  }
357
 
 
358
  log_entry = _create_decision_log(
359
+ agent_name="ROUTE_PLANNER", step_type="PROPOSAL_2",
 
360
  input_snapshot=planner_agent.get_input_snapshot(effort_result, penalties),
361
  output_snapshot=planner_agent.get_output_snapshot(proposal2),
362
  )
363
 
 
364
  _publish_event_sync(run_id, "ROUTE_PLANNER", "PROPOSAL_2", "COMPLETED", {
365
+ "total_effort": proposal2.total_effort, "solver_status": proposal2.solver_status,
 
 
366
  })
367
 
368
+ return {"route_proposal_2": proposal_dict, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
 
369
 
370
 
371
  # =============================================================================
 
373
  # =============================================================================
374
 
375
  def select_final_proposal_node(state: AllocationState) -> Dict[str, Any]:
376
+ """Select best proposal based on fairness metrics comparison."""
 
 
 
 
 
377
  final_proposal = state.route_proposal_1
378
  final_fairness = state.fairness_check_1
379
 
380
  if state.route_proposal_2 and state.fairness_check_2:
 
381
  check1_metrics = state.fairness_check_1["metrics"]
382
  check2_metrics = state.fairness_check_2["metrics"]
 
 
383
  if (check2_metrics["gini_index"] <= check1_metrics["gini_index"] or
384
  check2_metrics["max_gap"] < check1_metrics["max_gap"]):
385
  final_proposal = state.route_proposal_2
386
  final_fairness = state.fairness_check_2
387
 
388
+ return {"final_proposal": final_proposal, "final_fairness": final_fairness, "final_per_driver_effort": final_proposal["per_driver_effort"]}
 
 
 
 
389
 
390
 
391
  # =============================================================================
 
393
  # =============================================================================
394
 
395
  def driver_liaison_node(state: AllocationState) -> Dict[str, Any]:
396
+ """LangGraph node #6: Driver Liaison - per-driver comfort band negotiation."""
 
 
 
 
 
397
  run_id = state.allocation_run_id
398
 
399
+ _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "STARTED", {"num_drivers": len(state.driver_models)})
 
 
 
400
 
401
  from app.schemas.agent_schemas import AllocationItem
 
402
  liaison_agent = DriverLiaisonAgent()
403
 
404
  final_proposal = state.final_proposal or state.route_proposal_1
405
  final_fairness = state.final_fairness or state.fairness_check_1
406
 
407
+ sorted_allocations = sorted(final_proposal["allocation"], key=lambda x: x["effort"], reverse=True)
 
 
 
 
 
 
408
  driver_proposals: List[DriverAssignmentProposal] = []
409
  for rank, alloc_item in enumerate(sorted_allocations, start=1):
410
  driver_proposals.append(DriverAssignmentProposal(
411
+ driver_id=str(alloc_item["driver_id"]), route_id=str(alloc_item["route_id"]),
412
+ effort=alloc_item["effort"], rank_in_team=rank,
 
 
413
  ))
414
 
 
415
  metrics = final_fairness["metrics"]
 
 
 
 
416
  driver_context_objs: Dict[str, DriverContext] = {}
417
+ for driver_id, context_dict in (state.driver_contexts or {}).items():
418
  driver_context_objs[driver_id] = DriverContext(**context_dict)
419
 
 
420
  negotiation_result = liaison_agent.run_for_all_drivers(
421
+ proposals=driver_proposals, driver_contexts=driver_context_objs,
422
+ effort_matrix=state.effort_matrix["matrix"], driver_ids=state.effort_matrix["driver_ids"],
 
 
423
  route_ids=state.effort_matrix["route_ids"],
424
+ global_avg_effort=metrics["avg_effort"], global_std_effort=metrics["std_dev"],
 
425
  )
426
 
 
427
  liaison_dict = {
428
  "decisions": [d.model_dump() if hasattr(d, 'model_dump') else d for d in negotiation_result.decisions],
429
  "num_accept": negotiation_result.num_accept,
 
431
  "num_force_accept": negotiation_result.num_force_accept,
432
  }
433
 
 
434
  log_entry = _create_decision_log(
435
+ agent_name="DRIVER_LIAISON", step_type="NEGOTIATION_DECISIONS",
436
+ input_snapshot=liaison_agent.get_input_snapshot(driver_proposals, metrics["avg_effort"], metrics["std_dev"]),
 
 
 
 
 
437
  output_snapshot=liaison_agent.get_output_snapshot(negotiation_result),
438
  )
439
 
 
440
  _publish_event_sync(run_id, "DRIVER_LIAISON", "NEGOTIATION", "COMPLETED", {
441
+ "num_accept": negotiation_result.num_accept, "num_counter": negotiation_result.num_counter,
 
 
442
  })
443
 
444
+ return {"liaison_feedback": liaison_dict, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
445
 
446
 
447
  # =============================================================================
 
449
  # =============================================================================
450
 
451
  def final_resolution_node(state: AllocationState) -> Dict[str, Any]:
452
+ """LangGraph node #7: Final Resolution - resolves COUNTER decisions via swaps."""
 
 
 
 
 
453
  run_id = state.allocation_run_id
454
  from app.schemas.agent_schemas import RoutePlanResult, AllocationItem, FairnessMetrics, DriverLiaisonDecision
455
 
456
+ counter_decisions = [d for d in state.liaison_feedback["decisions"] if d["decision"] == "COUNTER"]
 
 
 
 
457
 
458
  if not counter_decisions:
459
+ _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"reason": "no_counters", "swaps_applied": 0})
460
+ return {"resolution_result": {"swaps_applied": []}}
 
 
 
 
 
 
 
461
 
462
+ _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "STARTED", {"num_counters": len(counter_decisions)})
 
 
 
463
 
464
  resolution_agent = FinalResolutionAgent()
 
 
465
  final_proposal = state.final_proposal or state.route_proposal_1
466
  final_fairness = state.final_fairness or state.fairness_check_1
467
 
468
  approved_proposal = RoutePlanResult(
469
  allocation=[AllocationItem(**a) for a in final_proposal["allocation"]],
470
  total_effort=final_proposal["total_effort"],
471
+ avg_effort=final_proposal.get("avg_effort", final_proposal["total_effort"] / max(len(final_proposal["allocation"]), 1)),
472
  solver_status=final_proposal.get("solver_status", "OPTIMAL"),
473
  proposal_number=final_proposal["proposal_number"],
474
  per_driver_effort=final_proposal["per_driver_effort"],
475
  )
476
 
477
  decisions = [DriverLiaisonDecision(**d) for d in state.liaison_feedback["decisions"]]
 
478
  current_metrics = FairnessMetrics(**final_fairness["metrics"])
479
 
 
480
  resolution_result = resolution_agent.resolve_counters(
481
+ approved_proposal=approved_proposal, decisions=decisions,
 
482
  effort_matrix=state.effort_matrix["matrix"],
483
+ driver_ids=state.effort_matrix["driver_ids"], route_ids=state.effort_matrix["route_ids"],
 
484
  current_metrics=current_metrics,
485
  )
486
 
 
487
  resolution_dict = {
488
  "swaps_applied": [s.model_dump() if hasattr(s, 'model_dump') else s for s in resolution_result.swaps_applied],
489
+ "allocation": resolution_result.allocation,
490
  "per_driver_effort": resolution_result.per_driver_effort,
491
  "metrics": resolution_result.metrics,
492
  }
493
 
 
494
  log_entry = _create_decision_log(
495
+ agent_name="FINAL_RESOLUTION", step_type="SWAP_RESOLUTION",
496
+ input_snapshot=resolution_agent.get_input_snapshot(len(counter_decisions), current_metrics, final_fairness["metrics"]["avg_effort"]),
 
 
 
 
 
497
  output_snapshot=resolution_agent.get_output_snapshot(resolution_result),
498
  )
499
 
500
+ _publish_event_sync(run_id, "FINAL_RESOLUTION", "SWAP_RESOLUTION", "COMPLETED", {"swaps_applied": len(resolution_result.swaps_applied)})
 
 
 
501
 
502
+ updated_effort = state.final_per_driver_effort.copy() if state.final_per_driver_effort else {}
 
503
  if resolution_result.swaps_applied:
504
  updated_effort = resolution_result.per_driver_effort
505
 
506
+ return {"resolution_result": resolution_dict, "final_per_driver_effort": updated_effort, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
 
507
 
508
 
509
  # =============================================================================
 
511
  # =============================================================================
512
 
513
  def explainability_node(state: AllocationState) -> Dict[str, Any]:
514
+ """LangGraph node #8: Explainability Agent — generates per-driver explanations."""
 
 
 
 
 
515
  run_id = state.allocation_run_id
516
 
517
+ _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "STARTED", {"num_drivers": len(state.driver_models)})
 
 
 
518
 
519
  explain_agent = ExplainabilityAgent()
 
520
  final_proposal = state.final_proposal or state.route_proposal_1
521
  final_fairness = state.final_fairness or state.fairness_check_1
522
  final_per_driver_effort = state.final_per_driver_effort or final_proposal["per_driver_effort"]
 
524
  metrics = final_fairness["metrics"]
525
  avg_effort = metrics["avg_effort"]
526
 
 
527
  route_by_id = {str(r["id"]): r for r in state.route_models}
528
  driver_by_id = {str(d["id"]): d for d in state.driver_models}
529
+ route_dict_by_id = {str(r["id"]): rd for r, rd in zip(state.route_models, state.route_dicts)} if state.route_dicts else {}
530
 
531
+ sorted_efforts = sorted(final_per_driver_effort.items(), key=lambda x: x[1], reverse=True)
 
 
 
 
 
532
  rank_by_driver = {did: idx + 1 for idx, (did, _) in enumerate(sorted_efforts)}
533
  num_drivers = len(final_per_driver_effort)
534
 
 
535
  liaison_by_driver = {}
536
  if state.liaison_feedback:
537
  for decision in state.liaison_feedback["decisions"]:
538
  liaison_by_driver[decision["driver_id"]] = decision
539
 
 
540
  swapped_drivers = set()
541
  if state.resolution_result and state.resolution_result.get("swaps_applied"):
542
  for swap in state.resolution_result["swaps_applied"]:
543
+ swapped_drivers.add(swap.get("driver_a", ""))
544
+ swapped_drivers.add(swap.get("driver_b", ""))
545
 
546
  explanations: Dict[str, Dict[str, Any]] = {}
547
  category_counts: Dict[str, int] = {}
 
552
 
553
  driver = driver_by_id.get(driver_id_str, {})
554
  route = route_by_id.get(route_id_str, {})
 
555
 
 
556
  effort = final_per_driver_effort.get(driver_id_str, alloc_item["effort"])
557
  fairness_score = calculate_fairness_score(effort, avg_effort)
558
+ driver_context = (state.driver_contexts or {}).get(driver_id_str, {})
 
 
559
  history_efforts = [driver_context.get("recent_avg_effort", avg_effort)] if driver_context else []
560
  history_hard_days = driver_context.get("recent_hard_days", 0) if driver_context else 0
561
 
 
562
  breakdown_key = f"{driver_id_str}:{route_id_str}"
563
  effort_breakdown_data = state.effort_matrix.get("breakdown", {}).get(breakdown_key, {})
564
  effort_breakdown = {
 
567
  "time_pressure": effort_breakdown_data.get("time_pressure", 0),
568
  }
569
 
 
570
  liaison_decision = liaison_by_driver.get(driver_id_str)
571
+ is_recovery = history_hard_days >= 3 and effort < avg_effort * 0.85
572
 
 
 
 
 
 
 
 
573
  explain_input = DriverExplanationInput(
574
+ driver_id=driver_id_str, driver_name=driver.get("name", "Driver"),
575
+ num_drivers=num_drivers, today_effort=effort,
 
 
576
  today_rank=rank_by_driver.get(driver_id_str, num_drivers),
577
  route_id=route_id_str,
578
+ route_summary={"num_packages": route.get("num_packages", 0), "total_weight_kg": route.get("total_weight_kg", 0), "num_stops": route.get("num_stops", 0), "difficulty_score": route.get("route_difficulty_score", 0), "estimated_time_minutes": route.get("estimated_time_minutes", 0)},
 
 
 
 
 
 
579
  effort_breakdown=effort_breakdown,
580
+ global_avg_effort=avg_effort, global_std_effort=metrics["std_dev"],
581
+ global_gini_index=metrics["gini_index"], global_max_gap=metrics["max_gap"],
 
 
582
  history_efforts_last_7_days=history_efforts,
583
+ history_hard_days_last_7=history_hard_days, is_recovery_day=is_recovery,
584
+ had_manual_override=False,
 
585
  liaison_decision=liaison_decision["decision"] if liaison_decision else None,
586
  swap_applied=driver_id_str in swapped_drivers,
587
  )
588
 
 
589
  explain_output = explain_agent.build_explanation_for_driver(explain_input)
 
 
590
  category_counts[explain_output.category] = category_counts.get(explain_output.category, 0) + 1
 
591
  explanations[driver_id_str] = {
592
  "driver_explanation": explain_output.driver_explanation,
593
  "admin_explanation": explain_output.admin_explanation,
594
  "category": explain_output.category,
595
  }
596
 
 
597
  log_entry = _create_decision_log(
598
+ agent_name="EXPLAINABILITY", step_type="EXPLANATIONS_GENERATED",
599
+ input_snapshot=explain_agent.get_input_snapshot(num_drivers=num_drivers, avg_effort=avg_effort, std_effort=metrics["std_dev"], gini_index=metrics["gini_index"], category_counts=category_counts),
600
+ output_snapshot=explain_agent.get_output_snapshot(total_explanations=len(explanations), category_counts=category_counts),
 
 
 
 
 
 
 
 
 
 
601
  )
602
 
603
+ _publish_event_sync(run_id, "EXPLAINABILITY", "EXPLANATIONS", "COMPLETED", {"total_explanations": len(explanations), "categories": category_counts})
 
 
 
 
604
 
605
+ return {"explanations": explanations, "decision_logs": state.decision_logs + [log_entry]}
 
 
 
606
 
607
 
608
  # =============================================================================
 
610
  # =============================================================================
611
 
612
  def should_reoptimize(state: AllocationState) -> str:
613
+ """Conditional: re-optimize if fairness check 1 says REOPTIMIZE and no proposal 2 yet."""
 
 
 
 
 
 
614
  if state.fairness_check_1 and state.fairness_check_1.get("status") == "REOPTIMIZE":
615
  if not state.route_proposal_2:
616
  return "reoptimize"
 
618
 
619
 
620
  def has_counter_decisions(state: AllocationState) -> str:
621
+ """Conditional: check if any COUNTER decisions need resolution."""
 
 
 
 
 
 
622
  if state.liaison_feedback:
623
+ if sum(1 for d in state.liaison_feedback["decisions"] if d["decision"] == "COUNTER") > 0:
 
 
 
 
624
  return "resolve"
625
  return "skip"