RoyAalekh committed on
Commit
77677ad
·
1 Parent(s): 8d2e8fa

Add auditing metadata to RL scheduler outputs

Browse files
court_scheduler_rl.py CHANGED
@@ -97,13 +97,21 @@ class InteractivePipeline:
97
  console.print("\n[bold cyan]Step 1/7: EDA & Parameter Extraction[/bold cyan]")
98
 
99
  # Check if EDA was run recently
 
 
100
  param_dir = Path("reports/figures").glob("v0.4.0_*/params")
101
- recent_params = any(p.exists() and
102
  (datetime.now() - datetime.fromtimestamp(p.stat().st_mtime)).days < 1
103
  for p in param_dir)
104
-
105
  if recent_params and not Confirm.ask("EDA parameters found. Regenerate?", default=False):
106
  console.print(" [green]OK[/green] Using existing EDA parameters")
 
 
 
 
 
 
107
  return
108
 
109
  with Progress(
@@ -127,10 +135,16 @@ class InteractivePipeline:
127
  run_load_and_clean()
128
  run_exploration()
129
  run_parameter_export()
130
-
131
  progress.update(task, completed=True)
132
-
133
  console.print(" [green]OK[/green] EDA pipeline complete")
 
 
 
 
 
 
134
 
135
  def _step_2_data_generation(self):
136
  """Step 2: Generate Training Data"""
@@ -169,7 +183,10 @@ class InteractivePipeline:
169
  console.print(f" Episodes: {self.config.rl_training.episodes}, Learning Rate: {self.config.rl_training.learning_rate}")
170
 
171
  model_file = self.output.trained_model_file
172
-
 
 
 
173
  with Progress(
174
  SpinnerColumn(),
175
  TextColumn("[progress.description]{task.description}"),
@@ -201,12 +218,63 @@ class InteractivePipeline:
201
  episode_length=rl_cfg.episode_length_days,
202
  verbose=False # Disable internal printing
203
  )
204
-
205
  progress.update(training_task, completed=rl_cfg.episodes)
206
-
207
  # Save trained agent
208
  agent.save(model_file)
209
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  # Create symlink in models/ for backwards compatibility
211
  self.output.create_model_symlink()
212
 
@@ -270,18 +338,38 @@ class InteractivePipeline:
270
 
271
  sim = CourtSim(cfg, policy_cases)
272
  result = sim.run()
273
-
274
  progress.update(task, completed=100)
275
-
276
  results[policy] = {
277
  'result': result,
278
  'cases': policy_cases, # Use the deep-copied cases for this simulation
279
  'sim': sim,
280
  'dir': policy_dir
281
  }
282
-
283
  console.print(f" [green]OK[/green] {result.disposals:,} disposals ({result.disposals/len(cases):.1%})")
284
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  self.sim_results = results
286
  console.print(f" [green]OK[/green] All simulations complete")
287
 
 
97
  console.print("\n[bold cyan]Step 1/7: EDA & Parameter Extraction[/bold cyan]")
98
 
99
  # Check if EDA was run recently
100
+ from src import eda_config
101
+
102
  param_dir = Path("reports/figures").glob("v0.4.0_*/params")
103
+ recent_params = any(p.exists() and
104
  (datetime.now() - datetime.fromtimestamp(p.stat().st_mtime)).days < 1
105
  for p in param_dir)
106
+
107
  if recent_params and not Confirm.ask("EDA parameters found. Regenerate?", default=False):
108
  console.print(" [green]OK[/green] Using existing EDA parameters")
109
+ self.output.record_eda_metadata(
110
+ version=eda_config.VERSION,
111
+ used_cached=True,
112
+ params_path=self.output.eda_params,
113
+ figures_path=self.output.eda_figures,
114
+ )
115
  return
116
 
117
  with Progress(
 
135
  run_load_and_clean()
136
  run_exploration()
137
  run_parameter_export()
138
+
139
  progress.update(task, completed=True)
140
+
141
  console.print(" [green]OK[/green] EDA pipeline complete")
142
+ self.output.record_eda_metadata(
143
+ version=eda_config.VERSION,
144
+ used_cached=False,
145
+ params_path=self.output.eda_params,
146
+ figures_path=self.output.eda_figures,
147
+ )
148
 
149
  def _step_2_data_generation(self):
150
  """Step 2: Generate Training Data"""
 
183
  console.print(f" Episodes: {self.config.rl_training.episodes}, Learning Rate: {self.config.rl_training.learning_rate}")
184
 
185
  model_file = self.output.trained_model_file
186
+
187
+ def _safe_mean(values: List[float]) -> float:
188
+ return sum(values) / len(values) if values else 0.0
189
+
190
  with Progress(
191
  SpinnerColumn(),
192
  TextColumn("[progress.description]{task.description}"),
 
218
  episode_length=rl_cfg.episode_length_days,
219
  verbose=False # Disable internal printing
220
  )
221
+
222
  progress.update(training_task, completed=rl_cfg.episodes)
223
+
224
  # Save trained agent
225
  agent.save(model_file)
226
+
227
+ # Persist training stats for downstream consumers
228
+ self.output.save_training_stats(training_stats)
229
+
230
+ # Run a lightweight evaluation sweep for summary metrics
231
+ evaluation_stats = None
232
+ try:
233
+ from rl.training import evaluate_agent
234
+ from scheduler.data.case_generator import CaseGenerator
235
+
236
+ eval_gen = CaseGenerator(
237
+ start=date.today(),
238
+ end=date.today() + timedelta(days=60),
239
+ seed=self.config.seed + 99,
240
+ )
241
+ eval_cases = eval_gen.generate(min(rl_cfg.cases_per_episode, 500), stage_mix_auto=True)
242
+ evaluation_stats = evaluate_agent(
243
+ agent=agent,
244
+ test_cases=eval_cases,
245
+ episodes=5,
246
+ episode_length=rl_cfg.episode_length_days,
247
+ )
248
+ self.output.save_evaluation_stats(evaluation_stats)
249
+ except Exception as eval_err:
250
+ console.print(f" [yellow]WARNING[/yellow] Evaluation skipped: {eval_err}")
251
+
252
+ training_summary = {
253
+ "episodes": rl_cfg.episodes,
254
+ "cases_per_episode": rl_cfg.cases_per_episode,
255
+ "episode_length_days": rl_cfg.episode_length_days,
256
+ "learning_rate": rl_cfg.learning_rate,
257
+ "epsilon": {
258
+ "initial": rl_cfg.initial_epsilon,
259
+ "final": agent.epsilon,
260
+ },
261
+ "reward": {
262
+ "mean": _safe_mean(training_stats.get("total_rewards", [])),
263
+ "final": training_stats.get("total_rewards", [0])[-1] if training_stats.get("total_rewards") else 0.0,
264
+ },
265
+ "disposal_rate": {
266
+ "mean": _safe_mean(training_stats.get("disposal_rates", [])),
267
+ "final": training_stats.get("disposal_rates", [0])[-1] if training_stats.get("disposal_rates") else 0.0,
268
+ },
269
+ "states_explored_final": training_stats.get("states_explored", [len(agent.q_table)])[-1]
270
+ if training_stats.get("states_explored")
271
+ else len(agent.q_table),
272
+ "q_table_size": len(agent.q_table),
273
+ "total_updates": getattr(agent, "total_updates", 0),
274
+ }
275
+
276
+ self.output.record_training_summary(training_summary, evaluation_stats)
277
+
278
  # Create symlink in models/ for backwards compatibility
279
  self.output.create_model_symlink()
280
 
 
338
 
339
  sim = CourtSim(cfg, policy_cases)
340
  result = sim.run()
341
+
342
  progress.update(task, completed=100)
343
+
344
  results[policy] = {
345
  'result': result,
346
  'cases': policy_cases, # Use the deep-copied cases for this simulation
347
  'sim': sim,
348
  'dir': policy_dir
349
  }
350
+
351
  console.print(f" [green]OK[/green] {result.disposals:,} disposals ({result.disposals/len(cases):.1%})")
352
+
353
+ allocator_stats = sim.allocator.get_utilization_stats()
354
+ backlog = sum(1 for c in policy_cases if not c.is_disposed)
355
+
356
+ kpis = {
357
+ "policy": policy,
358
+ "disposals": result.disposals,
359
+ "disposal_rate": result.disposals / len(policy_cases),
360
+ "utilization": result.utilization,
361
+ "hearings_total": result.hearings_total,
362
+ "hearings_heard": result.hearings_heard,
363
+ "hearings_adjourned": result.hearings_adjourned,
364
+ "backlog": backlog,
365
+ "backlog_rate": backlog / len(policy_cases) if policy_cases else 0,
366
+ "fairness_gini": allocator_stats.get("load_balance_gini"),
367
+ "avg_daily_load": allocator_stats.get("avg_daily_load"),
368
+ "capacity_rejections": allocator_stats.get("capacity_rejections"),
369
+ }
370
+
371
+ self.output.record_simulation_kpis(policy, kpis)
372
+
373
  self.sim_results = results
374
  console.print(f" [green]OK[/green] All simulations complete")
375
 
scheduler/utils/output_manager.py CHANGED
@@ -6,7 +6,7 @@ No scattered files, no duplicate saves, single source of truth per run.
6
 
7
  from pathlib import Path
8
  from datetime import datetime
9
- from typing import Optional
10
  import json
11
  from dataclasses import asdict
12
 
@@ -30,7 +30,8 @@ class OutputManager:
30
  base_dir: Base directory for all outputs (default: outputs/runs)
31
  """
32
  self.run_id = run_id or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
33
-
 
34
  # Base paths
35
  project_root = Path(__file__).parent.parent.parent
36
  self.base_dir = base_dir or (project_root / "outputs" / "runs")
@@ -49,6 +50,9 @@ class OutputManager:
49
 
50
  # Reports subdirectories
51
  self.visualizations_dir = self.reports_dir / "visualizations"
 
 
 
52
 
53
  def create_structure(self):
54
  """Create all output directories."""
@@ -64,10 +68,18 @@ class OutputManager:
64
  self.visualizations_dir,
65
  ]:
66
  dir_path.mkdir(parents=True, exist_ok=True)
67
-
 
 
 
 
 
 
 
 
68
  def save_config(self, config):
69
  """Save pipeline configuration to run directory.
70
-
71
  Args:
72
  config: PipelineConfig or any dataclass
73
  """
@@ -76,6 +88,45 @@ class OutputManager:
76
  # Handle nested dataclasses (like rl_training)
77
  config_dict = asdict(config) if hasattr(config, '__dataclass_fields__') else config
78
  json.dump(config_dict, f, indent=2, default=str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  def get_policy_dir(self, policy_name: str) -> Path:
81
  """Get simulation directory for a specific policy.
@@ -102,7 +153,37 @@ class OutputManager:
102
  cause_list_dir = self.get_policy_dir(policy_name) / "cause_lists"
103
  cause_list_dir.mkdir(parents=True, exist_ok=True)
104
  return cause_list_dir
105
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  @property
107
  def training_cases_file(self) -> Path:
108
  """Path to generated training cases CSV."""
@@ -152,9 +233,38 @@ class OutputManager:
152
  # Fallback: copy file if symlinks not supported (Windows without dev mode)
153
  import shutil
154
  shutil.copy2(target, symlink_path)
155
-
156
  def __str__(self) -> str:
157
  return f"OutputManager(run_id='{self.run_id}', run_dir='{self.run_dir}')"
158
-
159
  def __repr__(self) -> str:
160
  return self.__str__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from pathlib import Path
8
  from datetime import datetime
9
+ from typing import Optional, Dict, Any
10
  import json
11
  from dataclasses import asdict
12
 
 
30
  base_dir: Base directory for all outputs (default: outputs/runs)
31
  """
32
  self.run_id = run_id or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
33
+ self.created_at = datetime.now().isoformat()
34
+
35
  # Base paths
36
  project_root = Path(__file__).parent.parent.parent
37
  self.base_dir = base_dir or (project_root / "outputs" / "runs")
 
50
 
51
  # Reports subdirectories
52
  self.visualizations_dir = self.reports_dir / "visualizations"
53
+
54
+ # Metadata paths
55
+ self.run_record_file = self.run_dir / "run_record.json"
56
 
57
  def create_structure(self):
58
  """Create all output directories."""
 
68
  self.visualizations_dir,
69
  ]:
70
  dir_path.mkdir(parents=True, exist_ok=True)
71
+
72
+ # Initialize run record with creation metadata if missing
73
+ if not self.run_record_file.exists():
74
+ self._update_run_record("run", {
75
+ "run_id": self.run_id,
76
+ "created_at": self.created_at,
77
+ "base_dir": str(self.run_dir),
78
+ })
79
+
80
  def save_config(self, config):
81
  """Save pipeline configuration to run directory.
82
+
83
  Args:
84
  config: PipelineConfig or any dataclass
85
  """
 
88
  # Handle nested dataclasses (like rl_training)
89
  config_dict = asdict(config) if hasattr(config, '__dataclass_fields__') else config
90
  json.dump(config_dict, f, indent=2, default=str)
91
+
92
+ self._update_run_record("config", {
93
+ "path": str(config_path),
94
+ "timestamp": datetime.now().isoformat(),
95
+ })
96
+
97
+ def save_training_stats(self, training_stats: Dict[str, Any]):
98
+ """Persist raw training statistics for auditing and dashboards."""
99
+
100
+ self.training_dir.mkdir(parents=True, exist_ok=True)
101
+ with open(self.training_stats_file, "w", encoding="utf-8") as f:
102
+ json.dump(training_stats, f, indent=2, default=str)
103
+
104
+ def save_evaluation_stats(self, evaluation_stats: Dict[str, Any]):
105
+ """Persist evaluation metrics for downstream analysis."""
106
+
107
+ eval_path = self.training_dir / "evaluation.json"
108
+ with open(eval_path, "w", encoding="utf-8") as f:
109
+ json.dump(evaluation_stats, f, indent=2, default=str)
110
+
111
+ self._update_run_record("evaluation", {
112
+ "path": str(eval_path),
113
+ "timestamp": datetime.now().isoformat(),
114
+ })
115
+
116
+ def record_training_summary(self, summary: Dict[str, Any], evaluation: Optional[Dict[str, Any]] = None):
117
+ """Save aggregated training/evaluation summary for dashboards."""
118
+
119
+ summary_path = self.training_dir / "summary.json"
120
+ payload = {
121
+ "summary": summary,
122
+ "evaluation": evaluation,
123
+ "updated_at": datetime.now().isoformat(),
124
+ }
125
+
126
+ with open(summary_path, "w", encoding="utf-8") as f:
127
+ json.dump(payload, f, indent=2, default=str)
128
+
129
+ self._update_run_record("training", payload)
130
 
131
  def get_policy_dir(self, policy_name: str) -> Path:
132
  """Get simulation directory for a specific policy.
 
153
  cause_list_dir = self.get_policy_dir(policy_name) / "cause_lists"
154
  cause_list_dir.mkdir(parents=True, exist_ok=True)
155
  return cause_list_dir
156
+
157
+ def record_eda_metadata(self, version: str, used_cached: bool, params_path: Path, figures_path: Path):
158
+ """Record EDA version/timestamp for auditability."""
159
+
160
+ payload = {
161
+ "version": version,
162
+ "timestamp": datetime.now().isoformat(),
163
+ "used_cached": used_cached,
164
+ "params_path": str(params_path),
165
+ "figures_path": str(figures_path),
166
+ }
167
+
168
+ self._update_run_record("eda", payload)
169
+
170
+ def record_simulation_kpis(self, policy: str, kpis: Dict[str, Any]):
171
+ """Persist simulation KPIs per policy for dashboards."""
172
+
173
+ policy_dir = self.get_policy_dir(policy)
174
+ metrics_path = policy_dir / "metrics.json"
175
+ with open(metrics_path, "w", encoding="utf-8") as f:
176
+ json.dump(kpis, f, indent=2, default=str)
177
+
178
+ record = self._load_run_record()
179
+ simulation_section = record.get("simulation", {})
180
+ simulation_section[policy] = kpis
181
+ record["simulation"] = simulation_section
182
+ record["updated_at"] = datetime.now().isoformat()
183
+
184
+ with open(self.run_record_file, "w", encoding="utf-8") as f:
185
+ json.dump(record, f, indent=2, default=str)
186
+
187
  @property
188
  def training_cases_file(self) -> Path:
189
  """Path to generated training cases CSV."""
 
233
  # Fallback: copy file if symlinks not supported (Windows without dev mode)
234
  import shutil
235
  shutil.copy2(target, symlink_path)
236
+
237
  def __str__(self) -> str:
238
  return f"OutputManager(run_id='{self.run_id}', run_dir='{self.run_dir}')"
239
+
240
  def __repr__(self) -> str:
241
  return self.__str__()
242
+
243
+ # ------------------------------------------------------------------
244
+ # Internal helpers
245
+ # ------------------------------------------------------------------
246
+ def _load_run_record(self) -> Dict[str, Any]:
247
+ """Load run record JSON, providing defaults if missing."""
248
+
249
+ if self.run_record_file.exists():
250
+ try:
251
+ with open(self.run_record_file, "r", encoding="utf-8") as f:
252
+ return json.load(f)
253
+ except json.JSONDecodeError:
254
+ pass
255
+
256
+ return {
257
+ "run_id": self.run_id,
258
+ "created_at": self.created_at,
259
+ }
260
+
261
+ def _update_run_record(self, section: str, payload: Dict[str, Any]):
262
+ """Upsert a section within the consolidated run record."""
263
+
264
+ record = self._load_run_record()
265
+ record.setdefault("sections", {})
266
+ record["sections"][section] = payload
267
+ record["updated_at"] = datetime.now().isoformat()
268
+
269
+ with open(self.run_record_file, "w", encoding="utf-8") as f:
270
+ json.dump(record, f, indent=2, default=str)