RoyAalekh committed
Commit 802c366 · 1 Parent(s): 8d2e8fa

Add episode-level reward helper for RL training

Files changed (2)
  1. rl/rewards.py +127 -0
  2. rl/training.py +20 -12
rl/rewards.py ADDED
@@ -0,0 +1,127 @@
+"""Shared reward helper utilities for RL agents.
+
+The helper operates on episode-level statistics so that reward shaping
+reflects system-wide outcomes (disposal rate, gap compliance, urgent
+case latency, and fairness across cases).
+"""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, Iterable, Optional
+
+import numpy as np
+
+from scheduler.core.case import Case
+
+
+@dataclass
+class EpisodeRewardHelper:
+    """Aggregates episode metrics and computes shaped rewards."""
+
+    total_cases: int
+    target_gap_days: int = 30
+    max_urgent_latency: int = 60
+    disposal_weight: float = 4.0
+    gap_weight: float = 1.5
+    urgent_weight: float = 2.0
+    fairness_weight: float = 1.0
+    _disposed_cases: int = 0
+    _hearing_counts: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
+    _urgent_latencies: list[float] = field(default_factory=list)
+
+    def _base_outcome_reward(self, case: Case, was_scheduled: bool, hearing_outcome: str) -> float:
+        """Preserve the original per-case shaping signals."""
+
+        reward = 0.0
+        if not was_scheduled:
+            return reward
+
+        # Base scheduling reward (small positive for taking action)
+        reward += 0.5
+
+        # Hearing outcome rewards
+        lower_outcome = hearing_outcome.lower()
+        if "disposal" in lower_outcome or "judgment" in lower_outcome or "settlement" in lower_outcome:
+            reward += 10.0  # Major positive for disposal
+        elif "progress" in lower_outcome and "adjourn" not in lower_outcome:
+            reward += 3.0  # Progress without disposal
+        elif "adjourn" in lower_outcome:
+            reward -= 3.0  # Negative for adjournment
+
+        # Urgency bonus
+        if case.is_urgent:
+            reward += 2.0
+
+        # Ripeness penalty
+        if hasattr(case, "ripeness_status") and case.ripeness_status not in ["RIPE", "UNKNOWN"]:
+            reward -= 4.0
+
+        # Long pending bonus (>365 days)
+        if case.age_days and case.age_days > 365:
+            reward += 2.0
+
+        return reward
+
+    def _fairness_score(self) -> float:
+        """Reward higher uniformity in hearing distribution."""
+
+        counts: Iterable[int] = self._hearing_counts.values()
+        if not counts:
+            return 0.0
+
+        counts_array = np.array(list(counts), dtype=float)
+        mean = np.mean(counts_array)
+        if mean == 0:
+            return 0.0
+
+        dispersion = np.std(counts_array) / (mean + 1e-6)
+        # Lower dispersion -> better fairness. Convert to reward in [0, 1].
+        fairness = max(0.0, 1.0 - dispersion)
+        return fairness
+
+    def compute_case_reward(
+        self,
+        case: Case,
+        was_scheduled: bool,
+        hearing_outcome: str,
+        current_date,
+        previous_gap_days: Optional[int] = None,
+    ) -> float:
+        """Compute reward using both local and episode-level signals."""
+
+        reward = self._base_outcome_reward(case, was_scheduled, hearing_outcome)
+
+        if not was_scheduled:
+            return reward
+
+        # Track disposals
+        if "disposal" in hearing_outcome.lower() or getattr(case, "is_disposed", False):
+            self._disposed_cases += 1
+
+        # Track hearing counts for fairness
+        self._hearing_counts[case.case_id] = case.hearing_count or self._hearing_counts[case.case_id] + 1
+
+        # Track urgent latencies
+        if case.is_urgent:
+            self._urgent_latencies.append(case.age_days or 0)
+
+        # Episode-level components
+        disposal_rate = (self._disposed_cases / self.total_cases) if self.total_cases else 0.0
+        reward += self.disposal_weight * disposal_rate
+
+        if previous_gap_days is not None:
+            gap_score = max(0.0, 1.0 - (previous_gap_days / self.target_gap_days))
+            reward += self.gap_weight * gap_score
+
+        if self._urgent_latencies:
+            avg_latency = float(np.mean(self._urgent_latencies))
+            latency_score = max(0.0, 1.0 - (avg_latency / self.max_urgent_latency))
+            reward += self.urgent_weight * latency_score
+
+        fairness = self._fairness_score()
+        reward += self.fairness_weight * fairness
+
+        return reward
+
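To exercise the helper in isolation, a minimal sketch along these lines works. The _StubCase dataclass below is a hypothetical stand-in for scheduler.core.case.Case that models only the attributes the helper actually reads, and the sketch assumes the repo (including the scheduler package) is importable:

from dataclasses import dataclass
from datetime import date

from rl.rewards import EpisodeRewardHelper


@dataclass
class _StubCase:
    # Hypothetical stand-in; only the attributes EpisodeRewardHelper
    # reads are modelled here.
    case_id: str
    is_urgent: bool = False
    age_days: int = 0
    hearing_count: int = 0
    is_disposed: bool = False


helper = EpisodeRewardHelper(total_cases=2)

# Urgent, long-pending case disposed: base 0.5 + disposal 10.0
# + urgency 2.0 + long-pending 2.0 = 14.5, plus the episode-level
# components on top.
r1 = helper.compute_case_reward(
    _StubCase("C-1", is_urgent=True, age_days=400, hearing_count=3),
    was_scheduled=True,
    hearing_outcome="Disposal by judgment",
    current_date=date(2024, 1, 15),
    previous_gap_days=15,
)

# Adjournment: base 0.5 - 3.0 = -2.5, partially offset by the
# episode-level terms accumulated so far.
r2 = helper.compute_case_reward(
    _StubCase("C-2", age_days=90, hearing_count=1),
    was_scheduled=True,
    hearing_outcome="Adjourned - counsel unavailable",
    current_date=date(2024, 1, 15),
    previous_gap_days=45,
)

print(f"disposed: {r1:.2f}, adjourned: {r2:.2f}")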
rl/training.py CHANGED
@@ -13,7 +13,8 @@ import random
 from scheduler.data.case_generator import CaseGenerator
 from scheduler.simulation.engine import CourtSim, CourtSimConfig
 from scheduler.core.case import Case, CaseStatus
-from .simple_agent import TabularQAgent, CaseState
+from .simple_agent import TabularQAgent
+from .rewards import EpisodeRewardHelper
 
 
 class RLTrainingEnvironment:
@@ -32,6 +33,7 @@ class RLTrainingEnvironment:
         self.horizon_days = horizon_days
         self.current_date = start_date
         self.episode_rewards = []
+        self.reward_helper = EpisodeRewardHelper(total_cases=len(cases))
 
     def reset(self) -> List[Case]:
         """Reset environment for new training episode.
@@ -42,6 +44,7 @@ class RLTrainingEnvironment:
         """
         self.current_date = self.start_date
         self.episode_rewards = []
+        self.reward_helper = EpisodeRewardHelper(total_cases=len(self.cases))
         return self.cases.copy()
 
     def step(self, agent_decisions: Dict[str, int]) -> Tuple[List[Case], Dict[str, float], bool]:
@@ -57,18 +60,23 @@ class RLTrainingEnvironment:
         rewards = {}
 
         # For each case that agent decided to schedule
-        scheduled_cases = [case for case in self.cases
+        scheduled_cases = [case for case in self.cases
                            if case.case_id in agent_decisions and agent_decisions[case.case_id] == 1]
 
         # Simulate hearing outcomes for scheduled cases
         for case in scheduled_cases:
             if case.is_disposed:
                 continue
-
+
             # Simulate hearing outcome based on stage transition probabilities
             outcome = self._simulate_hearing_outcome(case)
             was_heard = "heard" in outcome.lower()
-
+
+            # Track gap relative to previous hearing for reward shaping
+            previous_gap = None
+            if case.last_hearing_date:
+                previous_gap = max(0, (self.current_date - case.last_hearing_date).days)
+
             # Always record the hearing
             case.record_hearing(self.current_date, was_heard=was_heard, outcome=outcome)
 
@@ -83,7 +91,13 @@ class RLTrainingEnvironment:
             # If adjourned, case stays in same stage
 
             # Compute reward for this case
-            rewards[case.case_id] = self._compute_reward(case, outcome)
+            rewards[case.case_id] = self.reward_helper.compute_case_reward(
+                case,
+                was_scheduled=True,
+                hearing_outcome=outcome,
+                current_date=self.current_date,
+                previous_gap_days=previous_gap,
+            )
 
         # Update case ages
         for case in self.cases:
@@ -131,13 +145,7 @@ class RLTrainingEnvironment:
         # Default progression
         return "ARGUMENTS"
 
-    def _compute_reward(self, case: Case, outcome: str) -> float:
-        """Compute reward based on case and outcome."""
-        agent = TabularQAgent()  # Use for reward computation
-        return agent.compute_reward(case, was_scheduled=True, hearing_outcome=outcome)
-
-
-def train_agent(agent: TabularQAgent, episodes: int = 100,
+def train_agent(agent: TabularQAgent, episodes: int = 100,
                 cases_per_episode: int = 1000,
                 episode_length: int = 60,
                 verbose: bool = True) -> Dict:
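As a sanity check on the shaping ranges: each episode-level term is clipped to [0, 1] before its weight is applied, so the new previous_gap plumbing can only add reward, never subtract it. The standalone mirror below restates the formulas from rl/rewards.py for illustration only; these helper functions are not part of the commit:

import numpy as np

def gap_score(previous_gap_days: int, target_gap_days: int = 30) -> float:
    # 1.0 for back-to-back hearings, decaying to 0.0 once the gap reaches
    # the target; overdue gaps are clipped, not punished further.
    return max(0.0, 1.0 - previous_gap_days / target_gap_days)

def latency_score(avg_urgent_age_days: float, max_urgent_latency: int = 60) -> float:
    # 1.0 for freshly filed urgent cases, 0.0 at or beyond the latency cap.
    return max(0.0, 1.0 - avg_urgent_age_days / max_urgent_latency)

def fairness(hearing_counts: list[int]) -> float:
    # Coefficient-of-variation based: identical hearing counts -> 1.0,
    # heavily skewed distributions -> 0.0.
    arr = np.array(hearing_counts, dtype=float)
    if arr.size == 0 or arr.mean() == 0:
        return 0.0
    return max(0.0, 1.0 - arr.std() / (arr.mean() + 1e-6))

assert gap_score(0) == 1.0 and gap_score(15) == 0.5 and gap_score(45) == 0.0
assert latency_score(30) == 0.5 and latency_score(120) == 0.0
assert fairness([3, 3, 3]) > 0.99 and fairness([6, 0, 0]) == 0.0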