RoyAalekh commited on
Commit
a88786e
·
1 Parent(s): 8d2e8fa

Align RL training with scheduling algorithm constraints

Browse files
Files changed (3) hide show
  1. rl/config.py +21 -0
  2. rl/simple_agent.py +25 -7
  3. rl/training.py +275 -135
rl/config.py CHANGED
@@ -17,6 +17,14 @@ class RLTrainingConfig:
17
  episodes: int = 100
18
  cases_per_episode: int = 1000
19
  episode_length_days: int = 60
 
 
 
 
 
 
 
 
20
 
21
  # Q-learning hyperparameters
22
  learning_rate: float = 0.15
@@ -48,6 +56,19 @@ class RLTrainingConfig:
48
  if self.cases_per_episode < 1:
49
  raise ValueError(f"cases_per_episode must be >= 1, got {self.cases_per_episode}")
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  @dataclass
53
  class PolicyConfig:
 
17
  episodes: int = 100
18
  cases_per_episode: int = 1000
19
  episode_length_days: int = 60
20
+
21
+ # Courtroom + allocation constraints
22
+ courtrooms: int = 5
23
+ daily_capacity_per_courtroom: int = 151
24
+ cap_daily_allocations: bool = True
25
+ max_daily_allocations: int | None = None # Optional hard cap (overrides computed capacity)
26
+ enforce_min_gap: bool = True
27
+ apply_judge_preferences: bool = True
28
 
29
  # Q-learning hyperparameters
30
  learning_rate: float = 0.15
 
56
  if self.cases_per_episode < 1:
57
  raise ValueError(f"cases_per_episode must be >= 1, got {self.cases_per_episode}")
58
 
59
+ if self.courtrooms < 1:
60
+ raise ValueError(f"courtrooms must be >= 1, got {self.courtrooms}")
61
+
62
+ if self.daily_capacity_per_courtroom < 1:
63
+ raise ValueError(
64
+ f"daily_capacity_per_courtroom must be >= 1, got {self.daily_capacity_per_courtroom}"
65
+ )
66
+
67
+ if self.max_daily_allocations is not None and self.max_daily_allocations < 1:
68
+ raise ValueError(
69
+ f"max_daily_allocations must be >= 1 when provided, got {self.max_daily_allocations}"
70
+ )
71
+
72
 
73
  @dataclass
74
  class PolicyConfig:
rl/simple_agent.py CHANGED
@@ -18,15 +18,19 @@ from scheduler.core.case import Case
18
 
19
  @dataclass
20
  class CaseState:
21
- """6-dimensional state representation for a case."""
 
22
  stage_encoded: int # 0-7 for different stages
23
  age_days: float # normalized 0-1
24
- days_since_last: float # normalized 0-1
25
  urgency: int # 0 or 1
26
  ripe: int # 0 or 1
27
  hearing_count: float # normalized 0-1
28
-
29
- def to_tuple(self) -> Tuple[int, int, int, int, int, int]:
 
 
 
30
  """Convert to tuple for use as dict key."""
31
  return (
32
  self.stage_encoded,
@@ -34,7 +38,10 @@ class CaseState:
34
  min(9, int(self.days_since_last * 20)), # discretize to 20 bins, cap at 9
35
  self.urgency,
36
  self.ripe,
37
- min(9, int(self.hearing_count * 20)) # discretize to 20 bins, cap at 9
 
 
 
38
  )
39
 
40
 
@@ -77,7 +84,15 @@ class TabularQAgent:
77
  self.states_visited = set()
78
  self.total_updates = 0
79
 
80
- def extract_state(self, case: Case, current_date) -> CaseState:
 
 
 
 
 
 
 
 
81
  """Extract 6D state representation from a case.
82
 
83
  Args:
@@ -118,7 +133,10 @@ class TabularQAgent:
118
  days_since_last=days_since,
119
  urgency=urgency,
120
  ripe=ripe,
121
- hearing_count=hearing_count
 
 
 
122
  )
123
 
124
  def get_action(self, state: CaseState, training: bool = False) -> int:
 
18
 
19
  @dataclass
20
  class CaseState:
21
+ """Expanded state representation for a case with environment context."""
22
+
23
  stage_encoded: int # 0-7 for different stages
24
  age_days: float # normalized 0-1
25
+ days_since_last: float # normalized 0-1
26
  urgency: int # 0 or 1
27
  ripe: int # 0 or 1
28
  hearing_count: float # normalized 0-1
29
+ capacity_ratio: float # normalized 0-1 (remaining capacity for the day)
30
+ min_gap_days: int # encoded min gap rule in effect
31
+ preference_score: float # normalized 0-1 preference alignment
32
+
33
+ def to_tuple(self) -> Tuple[int, int, int, int, int, int, int, int, int]:
34
  """Convert to tuple for use as dict key."""
35
  return (
36
  self.stage_encoded,
 
38
  min(9, int(self.days_since_last * 20)), # discretize to 20 bins, cap at 9
39
  self.urgency,
40
  self.ripe,
41
+ min(9, int(self.hearing_count * 20)), # discretize to 20 bins, cap at 9
42
+ min(9, int(self.capacity_ratio * 10)),
43
+ min(30, self.min_gap_days),
44
+ min(9, int(self.preference_score * 10))
45
  )
46
 
47
 
 
84
  self.states_visited = set()
85
  self.total_updates = 0
86
 
87
+ def extract_state(
88
+ self,
89
+ case: Case,
90
+ current_date,
91
+ *,
92
+ capacity_ratio: float = 1.0,
93
+ min_gap_days: int = 7,
94
+ preference_score: float = 0.0,
95
+ ) -> CaseState:
96
  """Extract 6D state representation from a case.
97
 
98
  Args:
 
133
  days_since_last=days_since,
134
  urgency=urgency,
135
  ripe=ripe,
136
+ hearing_count=hearing_count,
137
+ capacity_ratio=max(0.0, min(1.0, capacity_ratio)),
138
+ min_gap_days=max(0, min_gap_days),
139
+ preference_score=max(0.0, min(1.0, preference_score))
140
  )
141
 
142
  def get_action(self, state: CaseState, training: bool = False) -> int:
rl/training.py CHANGED
@@ -6,36 +6,96 @@ case prioritization policies through simulation-based rewards.
6
 
7
  import numpy as np
8
  from pathlib import Path
9
- from typing import List, Tuple, Dict
10
  from datetime import date, timedelta
11
  import random
12
 
13
  from scheduler.data.case_generator import CaseGenerator
14
- from scheduler.simulation.engine import CourtSim, CourtSimConfig
15
  from scheduler.core.case import Case, CaseStatus
 
 
 
 
 
 
16
  from .simple_agent import TabularQAgent, CaseState
 
 
 
 
 
 
17
 
18
 
19
  class RLTrainingEnvironment:
20
  """Training environment for RL agent using court simulation."""
21
-
22
- def __init__(self, cases: List[Case], start_date: date, horizon_days: int = 90):
 
 
 
 
 
 
 
23
  """Initialize training environment.
24
-
25
  Args:
26
  cases: List of cases to simulate
27
  start_date: Simulation start date
28
  horizon_days: Training episode length in days
 
 
29
  """
30
  self.cases = cases
31
  self.start_date = start_date
32
  self.horizon_days = horizon_days
33
  self.current_date = start_date
34
  self.episode_rewards = []
35
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def reset(self) -> List[Case]:
37
  """Reset environment for new training episode.
38
-
39
  Note: In practice, train_agent() generates fresh cases per episode,
40
  so case state doesn't need resetting. This method just resets
41
  environment state (date, rewards).
@@ -43,70 +103,94 @@ class RLTrainingEnvironment:
43
  self.current_date = self.start_date
44
  self.episode_rewards = []
45
  return self.cases.copy()
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def step(self, agent_decisions: Dict[str, int]) -> Tuple[List[Case], Dict[str, float], bool]:
48
- """Execute one day of simulation with agent decisions.
49
-
50
- Args:
51
- agent_decisions: Dict mapping case_id to action (0=skip, 1=schedule)
52
-
53
- Returns:
54
- (updated_cases, rewards, episode_done)
55
- """
56
- # Simulate one day with agent decisions
57
- rewards = {}
58
-
59
- # For each case that agent decided to schedule
60
- scheduled_cases = [case for case in self.cases
61
- if case.case_id in agent_decisions and agent_decisions[case.case_id] == 1]
62
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Simulate hearing outcomes for scheduled cases
64
  for case in scheduled_cases:
65
  if case.is_disposed:
66
  continue
67
-
68
- # Simulate hearing outcome based on stage transition probabilities
69
  outcome = self._simulate_hearing_outcome(case)
70
  was_heard = "heard" in outcome.lower()
71
-
72
- # Always record the hearing
73
  case.record_hearing(self.current_date, was_heard=was_heard, outcome=outcome)
74
-
75
  if was_heard:
76
- # Check if case progressed to terminal stage
77
  if outcome in ["FINAL DISPOSAL", "SETTLEMENT", "NA"]:
78
  case.status = CaseStatus.DISPOSED
79
  case.disposal_date = self.current_date
80
  elif outcome != "ADJOURNED":
81
- # Advance to next stage
82
  case.current_stage = outcome
83
- # If adjourned, case stays in same stage
84
-
85
- # Compute reward for this case
86
  rewards[case.case_id] = self._compute_reward(case, outcome)
87
-
88
  # Update case ages
89
  for case in self.cases:
90
  case.update_age(self.current_date)
91
-
92
  # Move to next day
93
  self.current_date += timedelta(days=1)
94
  episode_done = (self.current_date - self.start_date).days >= self.horizon_days
95
-
96
  return self.cases, rewards, episode_done
97
-
98
  def _simulate_hearing_outcome(self, case: Case) -> str:
99
  """Simulate hearing outcome based on stage and case characteristics."""
100
  # Simplified outcome simulation
101
  current_stage = case.current_stage
102
-
103
  # Terminal stages - high disposal probability
104
  if current_stage in ["ORDERS / JUDGMENT", "FINAL DISPOSAL"]:
105
  if random.random() < 0.7: # 70% chance of disposal
106
  return "FINAL DISPOSAL"
107
  else:
108
  return "ADJOURNED"
109
-
110
  # Early stages more likely to adjourn
111
  if current_stage in ["PRE-ADMISSION", "ADMISSION"]:
112
  if random.random() < 0.6: # 60% adjournment rate
@@ -117,7 +201,7 @@ class RLTrainingEnvironment:
117
  return "ADMISSION"
118
  else:
119
  return "EVIDENCE"
120
-
121
  # Mid-stages
122
  if current_stage in ["EVIDENCE", "ARGUMENTS"]:
123
  if random.random() < 0.4: # 40% adjournment rate
@@ -127,202 +211,258 @@ class RLTrainingEnvironment:
127
  return "ARGUMENTS"
128
  else:
129
  return "ORDERS / JUDGMENT"
130
-
131
  # Default progression
132
  return "ARGUMENTS"
133
-
134
  def _compute_reward(self, case: Case, outcome: str) -> float:
135
  """Compute reward based on case and outcome."""
136
  agent = TabularQAgent() # Use for reward computation
137
  return agent.compute_reward(case, was_scheduled=True, hearing_outcome=outcome)
138
 
139
 
140
- def train_agent(agent: TabularQAgent, episodes: int = 100,
141
- cases_per_episode: int = 1000,
142
- episode_length: int = 60,
143
- verbose: bool = True) -> Dict:
144
- """Train RL agent using episodic simulation.
145
-
146
- Args:
147
- agent: TabularQAgent to train
148
- episodes: Number of training episodes
149
- cases_per_episode: Number of cases per episode
150
- episode_length: Episode length in days
151
- verbose: Print training progress
152
-
153
- Returns:
154
- Training statistics
155
- """
156
  training_stats = {
157
  "episodes": [],
158
  "total_rewards": [],
159
  "disposal_rates": [],
160
  "states_explored": [],
161
- "q_updates": []
162
  }
163
-
164
  if verbose:
165
- print(f"Training RL agent for {episodes} episodes...")
166
-
167
- for episode in range(episodes):
168
  # Generate fresh cases for this episode
169
  start_date = date(2024, 1, 1) + timedelta(days=episode * 10)
170
  end_date = start_date + timedelta(days=30)
171
-
172
- generator = CaseGenerator(start=start_date, end=end_date, seed=42 + episode)
173
- cases = generator.generate(cases_per_episode, stage_mix_auto=True)
174
-
 
 
 
 
175
  # Initialize training environment
176
- env = RLTrainingEnvironment(cases, start_date, episode_length)
177
-
 
 
 
 
 
 
178
  # Reset environment
179
  episode_cases = env.reset()
180
  episode_reward = 0.0
181
-
 
 
182
  # Run episode
183
- for day in range(episode_length):
184
  # Get eligible cases (not disposed, basic filtering)
185
  eligible_cases = [c for c in episode_cases if not c.is_disposed]
186
  if not eligible_cases:
187
  break
188
-
189
  # Agent makes decisions for each case
190
  agent_decisions = {}
191
  case_states = {}
192
-
193
- for case in eligible_cases[:100]: # Limit to 100 cases per day for efficiency
194
- state = agent.extract_state(case, env.current_date)
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  action = agent.get_action(state, training=True)
 
 
 
 
 
 
196
  agent_decisions[case.case_id] = action
197
  case_states[case.case_id] = state
198
-
199
  # Environment step
200
- updated_cases, rewards, done = env.step(agent_decisions)
201
-
202
  # Update Q-values based on rewards
203
  for case_id, reward in rewards.items():
204
  if case_id in case_states:
205
  state = case_states[case_id]
206
- action = agent_decisions[case_id]
207
-
208
- # Simple Q-update (could be improved with next state)
209
  agent.update_q_value(state, action, reward)
210
  episode_reward += reward
211
-
212
  if done:
213
  break
214
-
215
  # Compute episode statistics
216
  disposed_count = sum(1 for c in episode_cases if c.is_disposed)
217
  disposal_rate = disposed_count / len(episode_cases) if episode_cases else 0.0
218
-
219
  # Record statistics
220
  training_stats["episodes"].append(episode)
221
  training_stats["total_rewards"].append(episode_reward)
222
  training_stats["disposal_rates"].append(disposal_rate)
223
  training_stats["states_explored"].append(len(agent.states_visited))
224
  training_stats["q_updates"].append(agent.total_updates)
225
-
226
  # Decay exploration
227
- if episode > 0 and episode % 20 == 0:
228
- agent.epsilon = max(0.01, agent.epsilon * 0.9)
229
-
230
  if verbose and (episode + 1) % 10 == 0:
231
- print(f"Episode {episode + 1}/{episodes}: "
232
- f"Reward={episode_reward:.1f}, "
233
- f"Disposal={disposal_rate:.1%}, "
234
- f"States={len(agent.states_visited)}, "
235
- f"Epsilon={agent.epsilon:.3f}")
236
-
 
 
237
  if verbose:
238
  final_stats = agent.get_stats()
239
  print(f"\nTraining complete!")
240
  print(f"States explored: {final_stats['states_visited']}")
241
  print(f"Q-table size: {final_stats['q_table_size']}")
242
  print(f"Total updates: {final_stats['total_updates']}")
243
-
244
  return training_stats
245
 
246
 
247
- def evaluate_agent(agent: TabularQAgent, test_cases: List[Case],
248
- episodes: int = 10, episode_length: int = 90) -> Dict:
249
- """Evaluate trained agent performance.
250
-
251
- Args:
252
- agent: Trained TabularQAgent
253
- test_cases: Test cases for evaluation
254
- episodes: Number of evaluation episodes
255
- episode_length: Episode length in days
256
-
257
- Returns:
258
- Evaluation metrics
259
- """
260
  # Set agent to evaluation mode (no exploration)
261
  original_epsilon = agent.epsilon
262
  agent.epsilon = 0.0
263
-
 
 
 
264
  evaluation_stats = {
265
  "disposal_rates": [],
266
  "total_hearings": [],
267
  "avg_hearing_to_disposal": [],
268
- "utilization": []
269
  }
270
-
271
- print(f"Evaluating agent on {episodes} test episodes...")
272
-
273
- for episode in range(episodes):
 
 
 
 
 
274
  start_date = date(2024, 6, 1) + timedelta(days=episode * 10)
275
- env = RLTrainingEnvironment(test_cases.copy(), start_date, episode_length)
276
-
 
 
 
 
 
 
277
  episode_cases = env.reset()
278
  total_hearings = 0
279
-
280
  # Run evaluation episode
281
- for day in range(episode_length):
282
  eligible_cases = [c for c in episode_cases if not c.is_disposed]
283
  if not eligible_cases:
284
  break
285
-
 
 
 
286
  # Agent makes decisions (no exploration)
287
  agent_decisions = {}
288
- for case in eligible_cases[:100]:
289
- state = agent.extract_state(case, env.current_date)
 
 
 
 
 
 
 
 
290
  action = agent.get_action(state, training=False)
 
 
 
 
 
291
  agent_decisions[case.case_id] = action
292
-
293
  # Environment step
294
- updated_cases, rewards, done = env.step(agent_decisions)
295
  total_hearings += len([r for r in rewards.values() if r != 0])
296
-
297
  if done:
298
  break
299
-
300
  # Compute metrics
301
  disposed_count = sum(1 for c in episode_cases if c.is_disposed)
302
  disposal_rate = disposed_count / len(episode_cases)
303
-
304
  disposed_cases = [c for c in episode_cases if c.is_disposed]
305
  avg_hearings = np.mean([c.hearing_count for c in disposed_cases]) if disposed_cases else 0
306
-
307
  evaluation_stats["disposal_rates"].append(disposal_rate)
308
  evaluation_stats["total_hearings"].append(total_hearings)
309
  evaluation_stats["avg_hearing_to_disposal"].append(avg_hearings)
310
- evaluation_stats["utilization"].append(total_hearings / (episode_length * 151 * 5)) # 151 capacity, 5 courts
311
-
312
  # Restore original epsilon
313
  agent.epsilon = original_epsilon
314
-
315
  # Compute summary statistics
316
  summary = {
317
  "mean_disposal_rate": np.mean(evaluation_stats["disposal_rates"]),
318
  "std_disposal_rate": np.std(evaluation_stats["disposal_rates"]),
319
  "mean_utilization": np.mean(evaluation_stats["utilization"]),
320
- "mean_hearings_to_disposal": np.mean(evaluation_stats["avg_hearing_to_disposal"])
321
  }
322
-
323
- print(f"Evaluation complete:")
324
  print(f"Mean disposal rate: {summary['mean_disposal_rate']:.1%} ± {summary['std_disposal_rate']:.1%}")
325
  print(f"Mean utilization: {summary['mean_utilization']:.1%}")
326
  print(f"Avg hearings to disposal: {summary['mean_hearings_to_disposal']:.1f}")
327
-
328
- return summary
 
6
 
7
  import numpy as np
8
  from pathlib import Path
9
+ from typing import List, Tuple, Dict, Optional
10
  from datetime import date, timedelta
11
  import random
12
 
13
  from scheduler.data.case_generator import CaseGenerator
 
14
  from scheduler.core.case import Case, CaseStatus
15
+ from scheduler.core.algorithm import SchedulingAlgorithm
16
+ from scheduler.core.courtroom import Courtroom
17
+ from scheduler.core.policy import SchedulerPolicy
18
+ from scheduler.simulation.policies.readiness import ReadinessPolicy
19
+ from scheduler.simulation.allocator import CourtroomAllocator, AllocationStrategy
20
+ from scheduler.control.overrides import Override, OverrideType, JudgePreferences
21
  from .simple_agent import TabularQAgent, CaseState
22
+ from .config import (
23
+ RLTrainingConfig,
24
+ PolicyConfig,
25
+ DEFAULT_RL_TRAINING_CONFIG,
26
+ DEFAULT_POLICY_CONFIG,
27
+ )
28
 
29
 
30
  class RLTrainingEnvironment:
31
  """Training environment for RL agent using court simulation."""
32
+
33
+ def __init__(
34
+ self,
35
+ cases: List[Case],
36
+ start_date: date,
37
+ horizon_days: int = 90,
38
+ rl_config: RLTrainingConfig | None = None,
39
+ policy_config: PolicyConfig | None = None,
40
+ ):
41
  """Initialize training environment.
42
+
43
  Args:
44
  cases: List of cases to simulate
45
  start_date: Simulation start date
46
  horizon_days: Training episode length in days
47
+ rl_config: RL-specific training constraints
48
+ policy_config: Policy knobs for ripeness/gap rules
49
  """
50
  self.cases = cases
51
  self.start_date = start_date
52
  self.horizon_days = horizon_days
53
  self.current_date = start_date
54
  self.episode_rewards = []
55
+ self.rl_config = rl_config or DEFAULT_RL_TRAINING_CONFIG
56
+ self.policy_config = policy_config or DEFAULT_POLICY_CONFIG
57
+
58
+ # Resources mirroring production defaults
59
+ self.courtrooms = [
60
+ Courtroom(
61
+ courtroom_id=i + 1,
62
+ judge_id=f"J{i+1:03d}",
63
+ daily_capacity=self.rl_config.daily_capacity_per_courtroom,
64
+ )
65
+ for i in range(self.rl_config.courtrooms)
66
+ ]
67
+ self.allocator = CourtroomAllocator(
68
+ num_courtrooms=self.rl_config.courtrooms,
69
+ per_courtroom_capacity=self.rl_config.daily_capacity_per_courtroom,
70
+ strategy=AllocationStrategy.LOAD_BALANCED,
71
+ )
72
+ self.policy: SchedulerPolicy = ReadinessPolicy()
73
+ self.algorithm = SchedulingAlgorithm(
74
+ policy=self.policy,
75
+ allocator=self.allocator,
76
+ min_gap_days=self.policy_config.min_gap_days if self.rl_config.enforce_min_gap else 0,
77
+ )
78
+ self.preferences = self._build_preferences()
79
+
80
+ def _build_preferences(self) -> Optional[JudgePreferences]:
81
+ """Synthetic judge preferences for training context."""
82
+ if not self.rl_config.apply_judge_preferences:
83
+ return None
84
+
85
+ capacity_overrides = {room.courtroom_id: room.daily_capacity for room in self.courtrooms}
86
+ return JudgePreferences(
87
+ judge_id="RL-JUDGE",
88
+ capacity_overrides=capacity_overrides,
89
+ case_type_preferences={
90
+ "Monday": ["RSA"],
91
+ "Tuesday": ["CCC"],
92
+ "Wednesday": ["NI ACT"],
93
+ },
94
+ )
95
+
96
  def reset(self) -> List[Case]:
97
  """Reset environment for new training episode.
98
+
99
  Note: In practice, train_agent() generates fresh cases per episode,
100
  so case state doesn't need resetting. This method just resets
101
  environment state (date, rewards).
 
103
  self.current_date = self.start_date
104
  self.episode_rewards = []
105
  return self.cases.copy()
106
+
107
+ def capacity_ratio(self, remaining_slots: int) -> float:
108
+ """Proportion of courtroom capacity still available for the day."""
109
+ total_capacity = self.rl_config.courtrooms * self.rl_config.daily_capacity_per_courtroom
110
+ return max(0.0, min(1.0, remaining_slots / total_capacity)) if total_capacity else 0.0
111
+
112
+ def preference_score(self, case: Case) -> float:
113
+ """Return 1.0 when case_type aligns with day-of-week preference, else 0."""
114
+ if not self.preferences:
115
+ return 0.0
116
+
117
+ day_name = self.current_date.strftime("%A")
118
+ preferred_types = self.preferences.case_type_preferences.get(day_name, [])
119
+ return 1.0 if case.case_type in preferred_types else 0.0
120
+
121
    def step(self, agent_decisions: Dict[str, int]) -> Tuple[List[Case], Dict[str, float], bool]:
        """Execute one day of simulation with agent decisions via SchedulingAlgorithm.

        Args:
            agent_decisions: Mapping of case_id -> action (0 = skip, 1 = schedule).

        Returns:
            Tuple of (all cases, per-case rewards for cases scheduled today,
            whether the episode horizon has been reached).
        """
        rewards: Dict[str, float] = {}

        # Convert agent schedule actions into priority overrides
        overrides: List[Override] = []
        priority_boost = 1.0
        for case in self.cases:
            if agent_decisions.get(case.case_id) == 1:
                overrides.append(
                    Override(
                        override_id=f"rl-{case.case_id}-{self.current_date.isoformat()}",
                        override_type=OverrideType.PRIORITY,
                        case_id=case.case_id,
                        judge_id="RL-JUDGE",
                        timestamp=self.current_date,
                        new_priority=case.get_priority_score() + priority_boost,
                    )
                )
                priority_boost += 0.1  # keep relative ordering stable

        # Run scheduling algorithm (capacity, ripeness, min-gap enforced)
        result = self.algorithm.schedule_day(
            cases=self.cases,
            courtrooms=self.courtrooms,
            current_date=self.current_date,
            overrides=overrides or None,
            preferences=self.preferences,
        )

        # Flatten scheduled cases (result.scheduled_cases maps room -> cases)
        scheduled_cases = [c for cases in result.scheduled_cases.values() for c in cases]

        # Simulate hearing outcomes for scheduled cases
        for case in scheduled_cases:
            if case.is_disposed:
                continue

            outcome = self._simulate_hearing_outcome(case)
            # NOTE(review): none of the outcome strings visibly produced by
            # _simulate_hearing_outcome ("FINAL DISPOSAL", "ADJOURNED", stage
            # names) contains "heard", so was_heard appears permanently False
            # here, which would block the disposal/stage-advance branch below
            # — confirm intended outcome vocabulary.
            was_heard = "heard" in outcome.lower()
            case.record_hearing(self.current_date, was_heard=was_heard, outcome=outcome)

            if was_heard:
                if outcome in ["FINAL DISPOSAL", "SETTLEMENT", "NA"]:
                    # Terminal outcome: dispose the case today.
                    case.status = CaseStatus.DISPOSED
                    case.disposal_date = self.current_date
                elif outcome != "ADJOURNED":
                    # Non-terminal, non-adjourned outcome names the next stage.
                    case.current_stage = outcome

            rewards[case.case_id] = self._compute_reward(case, outcome)

        # Update case ages
        for case in self.cases:
            case.update_age(self.current_date)

        # Move to next day
        self.current_date += timedelta(days=1)
        episode_done = (self.current_date - self.start_date).days >= self.horizon_days

        return self.cases, rewards, episode_done
181
+
182
  def _simulate_hearing_outcome(self, case: Case) -> str:
183
  """Simulate hearing outcome based on stage and case characteristics."""
184
  # Simplified outcome simulation
185
  current_stage = case.current_stage
186
+
187
  # Terminal stages - high disposal probability
188
  if current_stage in ["ORDERS / JUDGMENT", "FINAL DISPOSAL"]:
189
  if random.random() < 0.7: # 70% chance of disposal
190
  return "FINAL DISPOSAL"
191
  else:
192
  return "ADJOURNED"
193
+
194
  # Early stages more likely to adjourn
195
  if current_stage in ["PRE-ADMISSION", "ADMISSION"]:
196
  if random.random() < 0.6: # 60% adjournment rate
 
201
  return "ADMISSION"
202
  else:
203
  return "EVIDENCE"
204
+
205
  # Mid-stages
206
  if current_stage in ["EVIDENCE", "ARGUMENTS"]:
207
  if random.random() < 0.4: # 40% adjournment rate
 
211
  return "ARGUMENTS"
212
  else:
213
  return "ORDERS / JUDGMENT"
214
+
215
  # Default progression
216
  return "ARGUMENTS"
217
+
218
  def _compute_reward(self, case: Case, outcome: str) -> float:
219
  """Compute reward based on case and outcome."""
220
  agent = TabularQAgent() # Use for reward computation
221
  return agent.compute_reward(case, was_scheduled=True, hearing_outcome=outcome)
222
 
223
 
224
def train_agent(
    agent: TabularQAgent,
    rl_config: RLTrainingConfig = DEFAULT_RL_TRAINING_CONFIG,
    policy_config: PolicyConfig = DEFAULT_POLICY_CONFIG,
    verbose: bool = True,
) -> Dict:
    """Train RL agent using episodic simulation with courtroom constraints.

    Args:
        agent: TabularQAgent to train (its hyperparameters are overwritten
            from rl_config before training starts).
        rl_config: RL training constraints (episodes, capacity, epsilon schedule).
        policy_config: Policy knobs (e.g. min hearing gap) forwarded to the env.
        verbose: Print progress every 10 episodes and a final summary.

    Returns:
        Dict of per-episode training statistics (rewards, disposal rates,
        states explored, Q-update counts).
    """
    config = rl_config or DEFAULT_RL_TRAINING_CONFIG
    policy_cfg = policy_config or DEFAULT_POLICY_CONFIG

    # Align agent hyperparameters with config
    agent.learning_rate = config.learning_rate
    agent.discount = config.discount_factor
    agent.epsilon = config.initial_epsilon

    training_stats = {
        "episodes": [],
        "total_rewards": [],
        "disposal_rates": [],
        "states_explored": [],
        "q_updates": [],
    }

    if verbose:
        print(f"Training RL agent for {config.episodes} episodes...")

    for episode in range(config.episodes):
        # Generate fresh cases for this episode
        start_date = date(2024, 1, 1) + timedelta(days=episode * 10)
        end_date = start_date + timedelta(days=30)

        generator = CaseGenerator(
            start=start_date,
            end=end_date,
            seed=config.training_seed + episode,
        )
        cases = generator.generate(config.cases_per_episode, stage_mix_auto=config.stage_mix_auto)

        # Initialize training environment
        env = RLTrainingEnvironment(
            cases,
            start_date,
            config.episode_length_days,
            rl_config=config,
            policy_config=policy_cfg,
        )

        # Reset environment
        episode_cases = env.reset()
        episode_reward = 0.0

        total_capacity = config.courtrooms * config.daily_capacity_per_courtroom

        # Run episode
        for _ in range(config.episode_length_days):
            # Get eligible cases (not disposed, basic filtering)
            eligible_cases = [c for c in episode_cases if not c.is_disposed]
            if not eligible_cases:
                break

            # Agent makes decisions for each case
            agent_decisions = {}
            case_states = {}

            daily_cap = config.max_daily_allocations or total_capacity
            if not config.cap_daily_allocations:
                daily_cap = len(eligible_cases)
            remaining_slots = min(daily_cap, total_capacity) if config.cap_daily_allocations else daily_cap

            for case in eligible_cases[:daily_cap]:
                # BUGFIX: pass remaining_slots directly. The previous
                # `remaining_slots if remaining_slots else total_capacity`
                # reported FULL capacity (ratio 1.0) exactly when slots
                # were exhausted (remaining_slots == 0), feeding the agent
                # the opposite of the true state.
                cap_ratio = env.capacity_ratio(remaining_slots)
                pref_score = env.preference_score(case)
                state = agent.extract_state(
                    case,
                    env.current_date,
                    capacity_ratio=cap_ratio,
                    min_gap_days=policy_cfg.min_gap_days if config.enforce_min_gap else 0,
                    preference_score=pref_score,
                )
                action = agent.get_action(state, training=True)

                # Enforce the daily allocation cap: demote "schedule" to
                # "skip" once all slots are consumed, otherwise spend a slot.
                if config.cap_daily_allocations and action == 1 and remaining_slots <= 0:
                    action = 0
                elif action == 1 and config.cap_daily_allocations:
                    remaining_slots = max(0, remaining_slots - 1)

                agent_decisions[case.case_id] = action
                case_states[case.case_id] = state

            # Environment step
            _, rewards, done = env.step(agent_decisions)

            # Update Q-values based on rewards
            for case_id, reward in rewards.items():
                if case_id in case_states:
                    state = case_states[case_id]
                    action = agent_decisions.get(case_id, 0)
                    agent.update_q_value(state, action, reward)
                    episode_reward += reward

            if done:
                break

        # Compute episode statistics
        disposed_count = sum(1 for c in episode_cases if c.is_disposed)
        disposal_rate = disposed_count / len(episode_cases) if episode_cases else 0.0

        # Record statistics
        training_stats["episodes"].append(episode)
        training_stats["total_rewards"].append(episode_reward)
        training_stats["disposal_rates"].append(disposal_rate)
        training_stats["states_explored"].append(len(agent.states_visited))
        training_stats["q_updates"].append(agent.total_updates)

        # Decay exploration
        agent.epsilon = max(config.min_epsilon, agent.epsilon * config.epsilon_decay)

        if verbose and (episode + 1) % 10 == 0:
            print(
                f"Episode {episode + 1}/{config.episodes}: "
                f"Reward={episode_reward:.1f}, "
                f"Disposal={disposal_rate:.1%}, "
                f"States={len(agent.states_visited)}, "
                f"Epsilon={agent.epsilon:.3f}"
            )

    if verbose:
        final_stats = agent.get_stats()
        print(f"\nTraining complete!")
        print(f"States explored: {final_stats['states_visited']}")
        print(f"Q-table size: {final_stats['q_table_size']}")
        print(f"Total updates: {final_stats['total_updates']}")

    return training_stats
359
 
360
 
361
def evaluate_agent(
    agent: TabularQAgent,
    test_cases: List[Case],
    episodes: Optional[int] = None,
    episode_length: Optional[int] = None,
    rl_config: RLTrainingConfig = DEFAULT_RL_TRAINING_CONFIG,
    policy_config: PolicyConfig = DEFAULT_POLICY_CONFIG,
) -> Dict:
    """Evaluate trained agent performance.

    Args:
        agent: Trained TabularQAgent (exploration is disabled during the run
            and the original epsilon restored afterwards).
        test_cases: Test cases for evaluation.
        episodes: Number of evaluation episodes (defaults to 10).
        episode_length: Episode length in days (defaults to config value).
        rl_config: RL constraints (capacity, caps) mirrored from training.
        policy_config: Policy knobs forwarded to the environment.

    Returns:
        Summary metrics (mean/std disposal rate, mean utilization,
        mean hearings to disposal).
    """
    # Set agent to evaluation mode (no exploration)
    original_epsilon = agent.epsilon
    agent.epsilon = 0.0

    config = rl_config or DEFAULT_RL_TRAINING_CONFIG
    policy_cfg = policy_config or DEFAULT_POLICY_CONFIG

    evaluation_stats = {
        "disposal_rates": [],
        "total_hearings": [],
        "avg_hearing_to_disposal": [],
        "utilization": [],
    }

    eval_episodes = episodes if episodes is not None else 10
    eval_length = episode_length if episode_length is not None else config.episode_length_days

    print(f"Evaluating agent on {eval_episodes} test episodes...")

    total_capacity = config.courtrooms * config.daily_capacity_per_courtroom

    for episode in range(eval_episodes):
        start_date = date(2024, 6, 1) + timedelta(days=episode * 10)
        env = RLTrainingEnvironment(
            test_cases.copy(),
            start_date,
            eval_length,
            rl_config=config,
            policy_config=policy_cfg,
        )

        episode_cases = env.reset()
        total_hearings = 0

        # Run evaluation episode
        for _ in range(eval_length):
            eligible_cases = [c for c in episode_cases if not c.is_disposed]
            if not eligible_cases:
                break

            daily_cap = config.max_daily_allocations or total_capacity
            remaining_slots = min(daily_cap, total_capacity) if config.cap_daily_allocations else len(eligible_cases)

            # Agent makes decisions (no exploration)
            agent_decisions = {}
            for case in eligible_cases[:daily_cap]:
                # BUGFIX: pass remaining_slots directly. The previous falsy
                # fallback (`remaining_slots if remaining_slots else
                # total_capacity`) reported ratio 1.0 exactly when the day's
                # slots were exhausted.
                cap_ratio = env.capacity_ratio(remaining_slots)
                pref_score = env.preference_score(case)
                state = agent.extract_state(
                    case,
                    env.current_date,
                    capacity_ratio=cap_ratio,
                    min_gap_days=policy_cfg.min_gap_days if config.enforce_min_gap else 0,
                    preference_score=pref_score,
                )
                action = agent.get_action(state, training=False)
                # Enforce the daily allocation cap, mirroring training.
                if config.cap_daily_allocations and action == 1 and remaining_slots <= 0:
                    action = 0
                elif action == 1 and config.cap_daily_allocations:
                    remaining_slots = max(0, remaining_slots - 1)

                agent_decisions[case.case_id] = action

            # Environment step
            _, rewards, done = env.step(agent_decisions)
            total_hearings += len([r for r in rewards.values() if r != 0])

            if done:
                break

        # Compute metrics (guard empty case list, consistent with train_agent)
        disposed_count = sum(1 for c in episode_cases if c.is_disposed)
        disposal_rate = disposed_count / len(episode_cases) if episode_cases else 0.0

        disposed_cases = [c for c in episode_cases if c.is_disposed]
        avg_hearings = np.mean([c.hearing_count for c in disposed_cases]) if disposed_cases else 0

        evaluation_stats["disposal_rates"].append(disposal_rate)
        evaluation_stats["total_hearings"].append(total_hearings)
        evaluation_stats["avg_hearing_to_disposal"].append(avg_hearings)
        evaluation_stats["utilization"].append(total_hearings / (eval_length * total_capacity))

    # Restore original epsilon
    agent.epsilon = original_epsilon

    # Compute summary statistics
    summary = {
        "mean_disposal_rate": np.mean(evaluation_stats["disposal_rates"]),
        "std_disposal_rate": np.std(evaluation_stats["disposal_rates"]),
        "mean_utilization": np.mean(evaluation_stats["utilization"]),
        "mean_hearings_to_disposal": np.mean(evaluation_stats["avg_hearing_to_disposal"]),
    }

    print("Evaluation complete:")
    print(f"Mean disposal rate: {summary['mean_disposal_rate']:.1%} ± {summary['std_disposal_rate']:.1%}")
    print(f"Mean utilization: {summary['mean_utilization']:.1%}")
    print(f"Avg hearings to disposal: {summary['mean_hearings_to_disposal']:.1f}")

    return summary