Spaces:
Sleeping
feat: Implement OutputManager for clean output directory structure
- Created OutputManager class to centralize all pipeline output paths
- Single run directory per execution with timestamp-based ID
- Clear hierarchy: eda/, training/, simulation/, reports/
- Removed duplicate model saving, created symlink system
- Refactored EDA modules to use OutputManager paths
- Fixed scattered output files across data/, reports/, models/
- Config saved at run root for reproducibility
Changes:
- New: scheduler/utils/output_manager.py - centralized output management
- New: rl/config.py - structured RL configuration with presets
- New: docs/CONFIGURATION.md - 5-layer config architecture docs
- New: docs/OUTPUT_REFACTORING.md - implementation status tracker
- Modified: court_scheduler_rl.py - integrated OutputManager
- Modified: src/eda_*.py - dynamic output path configuration
- Modified: PipelineConfig - removed output_dir field
Benefits:
- No scattered files or duplicate saves
- Single source of truth per run
- Easy cleanup and archival
- Reproducible runs via saved config
- Clear separation of concerns
Test: Quick demo pipeline runs end-to-end successfully
Result: outputs/runs/run_TIMESTAMP/ with complete artifacts
- court_scheduler_rl.py +97 -74
- docs/CONFIGURATION.md +194 -0
- docs/OUTPUT_REFACTORING.md +88 -0
- rl/config.py +94 -0
- scheduler/utils/output_manager.py +160 -0
- src/eda_config.py +82 -19
- src/eda_exploration.py +26 -41
- src/eda_load_clean.py +6 -6
- src/eda_parameters.py +18 -18
|
@@ -13,7 +13,7 @@ from datetime import date, datetime, timedelta
|
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Dict, Any, Optional, List
|
| 15 |
import argparse
|
| 16 |
-
from dataclasses import dataclass, asdict
|
| 17 |
|
| 18 |
import typer
|
| 19 |
from rich.console import Console
|
|
@@ -38,36 +38,37 @@ class PipelineConfig:
|
|
| 38 |
stage_mix: str = "auto"
|
| 39 |
seed: int = 42
|
| 40 |
|
| 41 |
-
# RL Training
|
| 42 |
-
|
| 43 |
-
cases_per_episode: int = 1000
|
| 44 |
-
episode_length: int = 45
|
| 45 |
-
learning_rate: float = 0.15
|
| 46 |
-
initial_epsilon: float = 0.4
|
| 47 |
-
epsilon_decay: float = 0.99
|
| 48 |
-
min_epsilon: float = 0.05
|
| 49 |
|
| 50 |
# Simulation
|
| 51 |
sim_days: int = 730 # 2 years
|
| 52 |
sim_start_date: Optional[str] = None
|
| 53 |
policies: List[str] = None
|
| 54 |
|
| 55 |
-
# Output
|
| 56 |
-
output_dir: str = "data/hackathon_run"
|
| 57 |
generate_cause_lists: bool = True
|
| 58 |
generate_visualizations: bool = True
|
| 59 |
|
| 60 |
def __post_init__(self):
|
| 61 |
if self.policies is None:
|
| 62 |
self.policies = ["readiness", "rl"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
class InteractivePipeline:
|
| 65 |
"""Interactive pipeline orchestrator"""
|
| 66 |
|
| 67 |
-
def __init__(self, config: PipelineConfig):
|
| 68 |
self.config = config
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def run(self):
|
| 73 |
"""Execute complete pipeline"""
|
|
@@ -108,10 +109,17 @@ class InteractivePipeline:
|
|
| 108 |
with Progress(
|
| 109 |
SpinnerColumn(),
|
| 110 |
TextColumn("[progress.description]{task.description}"),
|
| 111 |
-
console=console
|
| 112 |
-
) as progress:
|
| 113 |
task = progress.add_task("Running EDA pipeline...", total=None)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
from src.eda_load_clean import run_load_and_clean
|
| 116 |
from src.eda_exploration import run_exploration
|
| 117 |
from src.eda_parameters import run_parameter_export
|
|
@@ -129,14 +137,13 @@ class InteractivePipeline:
|
|
| 129 |
console.print(f"\n[bold cyan]Step 2/7: Data Generation[/bold cyan]")
|
| 130 |
console.print(f" Generating {self.config.n_cases:,} cases ({self.config.start_date} to {self.config.end_date})")
|
| 131 |
|
| 132 |
-
cases_file = self.
|
| 133 |
|
| 134 |
with Progress(
|
| 135 |
SpinnerColumn(),
|
| 136 |
TextColumn("[progress.description]{task.description}"),
|
| 137 |
BarColumn(),
|
| 138 |
-
console=console
|
| 139 |
-
) as progress:
|
| 140 |
task = progress.add_task("Generating cases...", total=100)
|
| 141 |
|
| 142 |
from datetime import date as date_cls
|
|
@@ -159,55 +166,56 @@ class InteractivePipeline:
|
|
| 159 |
def _step_3_rl_training(self):
|
| 160 |
"""Step 3: RL Agent Training"""
|
| 161 |
console.print(f"\n[bold cyan]Step 3/7: RL Training[/bold cyan]")
|
| 162 |
-
console.print(f" Episodes: {self.config.episodes}, Learning Rate: {self.config.learning_rate}")
|
| 163 |
|
| 164 |
-
model_file = self.
|
| 165 |
|
| 166 |
with Progress(
|
| 167 |
SpinnerColumn(),
|
| 168 |
TextColumn("[progress.description]{task.description}"),
|
| 169 |
BarColumn(),
|
| 170 |
TimeElapsedColumn(),
|
| 171 |
-
console=console
|
| 172 |
-
|
| 173 |
-
training_task = progress.add_task("Training RL agent...", total=self.config.episodes)
|
| 174 |
|
| 175 |
# Import training components
|
| 176 |
from rl.training import train_agent
|
| 177 |
from rl.simple_agent import TabularQAgent
|
| 178 |
import pickle
|
| 179 |
|
| 180 |
-
# Initialize agent
|
|
|
|
| 181 |
agent = TabularQAgent(
|
| 182 |
-
learning_rate=
|
| 183 |
-
epsilon=
|
| 184 |
-
discount=
|
| 185 |
)
|
| 186 |
|
| 187 |
# Training with progress updates
|
| 188 |
# Note: train_agent handles its own progress internally
|
|
|
|
| 189 |
training_stats = train_agent(
|
| 190 |
agent=agent,
|
| 191 |
-
episodes=
|
| 192 |
-
cases_per_episode=
|
| 193 |
-
episode_length=
|
| 194 |
verbose=False # Disable internal printing
|
| 195 |
)
|
| 196 |
|
| 197 |
-
progress.update(training_task, completed=
|
| 198 |
|
| 199 |
# Save trained agent
|
| 200 |
agent.save(model_file)
|
| 201 |
|
| 202 |
-
#
|
| 203 |
-
|
| 204 |
-
models_dir.mkdir(exist_ok=True)
|
| 205 |
-
standard_model_path = models_dir / "trained_rl_agent.pkl"
|
| 206 |
-
agent.save(standard_model_path)
|
| 207 |
|
| 208 |
console.print(f" [green]OK[/green] Training complete -> {model_file}")
|
| 209 |
-
console.print(f" [green]OK[/green]
|
| 210 |
console.print(f" [green]OK[/green] Final epsilon: {agent.epsilon:.4f}, States explored: {len(agent.q_table)}")
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
def _step_4_simulation(self):
|
| 213 |
"""Step 4: 2-Year Simulation"""
|
|
@@ -215,7 +223,7 @@ class InteractivePipeline:
|
|
| 215 |
console.print(f" Duration: {self.config.sim_days} days ({self.config.sim_days/365:.1f} years)")
|
| 216 |
|
| 217 |
# Load cases
|
| 218 |
-
cases_file = self.
|
| 219 |
from scheduler.data.case_generator import CaseGenerator
|
| 220 |
cases = CaseGenerator.from_csv(cases_file)
|
| 221 |
|
|
@@ -227,36 +235,47 @@ class InteractivePipeline:
|
|
| 227 |
for policy in self.config.policies:
|
| 228 |
console.print(f"\n Running {policy} policy simulation...")
|
| 229 |
|
| 230 |
-
policy_dir = self.
|
| 231 |
policy_dir.mkdir(exist_ok=True)
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
with Progress(
|
| 234 |
SpinnerColumn(),
|
| 235 |
TextColumn(f"[progress.description]Simulating {policy}..."),
|
| 236 |
BarColumn(),
|
| 237 |
-
console=console
|
| 238 |
-
) as progress:
|
| 239 |
task = progress.add_task("Simulating...", total=100)
|
| 240 |
|
| 241 |
from scheduler.simulation.engine import CourtSim, CourtSimConfig
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
| 253 |
result = sim.run()
|
| 254 |
|
| 255 |
progress.update(task, completed=100)
|
| 256 |
|
| 257 |
results[policy] = {
|
| 258 |
'result': result,
|
| 259 |
-
'cases': cases
|
| 260 |
'sim': sim,
|
| 261 |
'dir': policy_dir
|
| 262 |
}
|
|
@@ -280,8 +299,7 @@ class InteractivePipeline:
|
|
| 280 |
with Progress(
|
| 281 |
SpinnerColumn(),
|
| 282 |
TextColumn("[progress.description]{task.description}"),
|
| 283 |
-
console=console
|
| 284 |
-
) as progress:
|
| 285 |
task = progress.add_task("Generating cause lists...", total=None)
|
| 286 |
|
| 287 |
from scheduler.output.cause_list import CauseListGenerator
|
|
@@ -305,8 +323,7 @@ class InteractivePipeline:
|
|
| 305 |
with Progress(
|
| 306 |
SpinnerColumn(),
|
| 307 |
TextColumn("[progress.description]{task.description}"),
|
| 308 |
-
console=console
|
| 309 |
-
) as progress:
|
| 310 |
task = progress.add_task("Analyzing results...", total=None)
|
| 311 |
|
| 312 |
# Generate comparison report
|
|
@@ -327,7 +344,7 @@ class InteractivePipeline:
|
|
| 327 |
summary = self._generate_executive_summary()
|
| 328 |
|
| 329 |
# Save summary
|
| 330 |
-
summary_file = self.
|
| 331 |
with open(summary_file, 'w') as f:
|
| 332 |
f.write(summary)
|
| 333 |
|
|
@@ -370,17 +387,17 @@ class InteractivePipeline:
|
|
| 370 |
|
| 371 |
console.print(Panel.fit(
|
| 372 |
f"[bold green]Pipeline Complete![/bold green]\n\n"
|
| 373 |
-
f"Results: {self.
|
| 374 |
f"Executive Summary: {summary_file}\n"
|
| 375 |
-
f"Visualizations: {self.
|
| 376 |
-
f"Cause Lists: {self.
|
| 377 |
f"[yellow]Ready for hackathon submission![/yellow]",
|
| 378 |
box=box.DOUBLE_EDGE
|
| 379 |
))
|
| 380 |
|
| 381 |
def _generate_comparison_report(self):
|
| 382 |
"""Generate detailed comparison report"""
|
| 383 |
-
report_file = self.
|
| 384 |
|
| 385 |
with open(report_file, 'w') as f:
|
| 386 |
f.write("# Court Scheduling System - Performance Comparison\n\n")
|
|
@@ -389,7 +406,9 @@ class InteractivePipeline:
|
|
| 389 |
f.write("## Configuration\n\n")
|
| 390 |
f.write(f"- Training Cases: {self.config.n_cases:,}\n")
|
| 391 |
f.write(f"- Simulation Period: {self.config.sim_days} days ({self.config.sim_days/365:.1f} years)\n")
|
| 392 |
-
f.write(f"- RL Episodes: {self.config.episodes}\n")
|
|
|
|
|
|
|
| 393 |
f.write(f"- Policies Compared: {', '.join(self.config.policies)}\n\n")
|
| 394 |
|
| 395 |
f.write("## Results Summary\n\n")
|
|
@@ -406,7 +425,7 @@ class InteractivePipeline:
|
|
| 406 |
|
| 407 |
def _generate_visualizations(self):
|
| 408 |
"""Generate performance visualizations"""
|
| 409 |
-
viz_dir = self.
|
| 410 |
viz_dir.mkdir(exist_ok=True)
|
| 411 |
|
| 412 |
# This would generate charts comparing policies
|
|
@@ -442,7 +461,7 @@ This intelligent court scheduling system uses Reinforcement Learning to optimize
|
|
| 442 |
**{disposal_rate:.1%} Case Disposal Rate** - Significantly improved case clearance
|
| 443 |
**{result.utilization:.1%} Court Utilization** - Optimal resource allocation
|
| 444 |
**{result.hearings_total:,} Hearings Scheduled** - Over {self.config.sim_days} days
|
| 445 |
-
**AI-Powered Decisions** - Reinforcement learning with {self.config.episodes} training episodes
|
| 446 |
|
| 447 |
### Technical Innovation
|
| 448 |
|
|
@@ -493,9 +512,15 @@ def get_interactive_config() -> PipelineConfig:
|
|
| 493 |
|
| 494 |
# RL Training
|
| 495 |
console.print("\n[bold]RL Training[/bold]")
|
|
|
|
|
|
|
| 496 |
episodes = IntPrompt.ask("Training episodes", default=100)
|
| 497 |
learning_rate = FloatPrompt.ask("Learning rate", default=0.15)
|
| 498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
# Simulation
|
| 500 |
console.print("\n[bold]Simulation[/bold]")
|
| 501 |
sim_days = IntPrompt.ask("Simulation days (730 = 2 years)", default=730)
|
|
@@ -514,14 +539,11 @@ def get_interactive_config() -> PipelineConfig:
|
|
| 514 |
n_cases=n_cases,
|
| 515 |
start_date=start_date,
|
| 516 |
end_date=end_date,
|
| 517 |
-
|
| 518 |
-
learning_rate=learning_rate,
|
| 519 |
sim_days=sim_days,
|
| 520 |
policies=policies,
|
| 521 |
-
output_dir=output_dir,
|
| 522 |
generate_cause_lists=generate_cause_lists,
|
| 523 |
-
generate_visualizations=generate_visualizations
|
| 524 |
-
)
|
| 525 |
|
| 526 |
@app.command()
|
| 527 |
def interactive():
|
|
@@ -532,7 +554,8 @@ def interactive():
|
|
| 532 |
console.print(f"\n[bold yellow]Configuration Summary:[/bold yellow]")
|
| 533 |
console.print(f" Cases: {config.n_cases:,}")
|
| 534 |
console.print(f" Period: {config.start_date} to {config.end_date}")
|
| 535 |
-
console.print(f" RL Episodes: {config.episodes}")
|
|
|
|
| 536 |
console.print(f" Simulation: {config.sim_days} days")
|
| 537 |
console.print(f" Policies: {', '.join(config.policies)}")
|
| 538 |
console.print(f" Output: {config.output_dir}")
|
|
@@ -561,12 +584,12 @@ def quick():
|
|
| 561 |
"""Run quick demo with default parameters"""
|
| 562 |
console.print("[bold blue]Quick Demo Pipeline[/bold blue]\n")
|
| 563 |
|
|
|
|
|
|
|
| 564 |
config = PipelineConfig(
|
| 565 |
n_cases=10000,
|
| 566 |
-
|
| 567 |
-
sim_days=90
|
| 568 |
-
output_dir="data/quick_demo",
|
| 569 |
-
)
|
| 570 |
|
| 571 |
pipeline = InteractivePipeline(config)
|
| 572 |
pipeline.run()
|
|
|
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Dict, Any, Optional, List
|
| 15 |
import argparse
|
| 16 |
+
from dataclasses import dataclass, asdict, field
|
| 17 |
|
| 18 |
import typer
|
| 19 |
from rich.console import Console
|
|
|
|
| 38 |
stage_mix: str = "auto"
|
| 39 |
seed: int = 42
|
| 40 |
|
| 41 |
+
# RL Training - delegate to RLTrainingConfig
|
| 42 |
+
rl_training: "RLTrainingConfig" = None # Will be set in __post_init__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Simulation
|
| 45 |
sim_days: int = 730 # 2 years
|
| 46 |
sim_start_date: Optional[str] = None
|
| 47 |
policies: List[str] = None
|
| 48 |
|
| 49 |
+
# Output (no longer user-configurable - managed by OutputManager)
|
|
|
|
| 50 |
generate_cause_lists: bool = True
|
| 51 |
generate_visualizations: bool = True
|
| 52 |
|
| 53 |
def __post_init__(self):
|
| 54 |
if self.policies is None:
|
| 55 |
self.policies = ["readiness", "rl"]
|
| 56 |
+
|
| 57 |
+
# Import here to avoid circular dependency
|
| 58 |
+
if self.rl_training is None:
|
| 59 |
+
from rl.config import DEFAULT_RL_TRAINING_CONFIG
|
| 60 |
+
self.rl_training = DEFAULT_RL_TRAINING_CONFIG
|
| 61 |
|
| 62 |
class InteractivePipeline:
|
| 63 |
"""Interactive pipeline orchestrator"""
|
| 64 |
|
| 65 |
+
def __init__(self, config: PipelineConfig, run_id: str = None):
|
| 66 |
self.config = config
|
| 67 |
+
|
| 68 |
+
from scheduler.utils.output_manager import OutputManager
|
| 69 |
+
self.output = OutputManager(run_id=run_id)
|
| 70 |
+
self.output.create_structure()
|
| 71 |
+
self.output.save_config(config)
|
| 72 |
|
| 73 |
def run(self):
|
| 74 |
"""Execute complete pipeline"""
|
|
|
|
| 109 |
with Progress(
|
| 110 |
SpinnerColumn(),
|
| 111 |
TextColumn("[progress.description]{task.description}"),
|
| 112 |
+
console=console) as progress:
|
|
|
|
| 113 |
task = progress.add_task("Running EDA pipeline...", total=None)
|
| 114 |
|
| 115 |
+
# Configure EDA output paths
|
| 116 |
+
from src.eda_config import set_output_paths
|
| 117 |
+
set_output_paths(
|
| 118 |
+
eda_dir=self.output.eda_figures,
|
| 119 |
+
data_dir=self.output.eda_data,
|
| 120 |
+
params_dir=self.output.eda_params
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
from src.eda_load_clean import run_load_and_clean
|
| 124 |
from src.eda_exploration import run_exploration
|
| 125 |
from src.eda_parameters import run_parameter_export
|
|
|
|
| 137 |
console.print(f"\n[bold cyan]Step 2/7: Data Generation[/bold cyan]")
|
| 138 |
console.print(f" Generating {self.config.n_cases:,} cases ({self.config.start_date} to {self.config.end_date})")
|
| 139 |
|
| 140 |
+
cases_file = self.output.training_cases_file
|
| 141 |
|
| 142 |
with Progress(
|
| 143 |
SpinnerColumn(),
|
| 144 |
TextColumn("[progress.description]{task.description}"),
|
| 145 |
BarColumn(),
|
| 146 |
+
console=console) as progress:
|
|
|
|
| 147 |
task = progress.add_task("Generating cases...", total=100)
|
| 148 |
|
| 149 |
from datetime import date as date_cls
|
|
|
|
| 166 |
def _step_3_rl_training(self):
|
| 167 |
"""Step 3: RL Agent Training"""
|
| 168 |
console.print(f"\n[bold cyan]Step 3/7: RL Training[/bold cyan]")
|
| 169 |
+
console.print(f" Episodes: {self.config.rl_training.episodes}, Learning Rate: {self.config.rl_training.learning_rate}")
|
| 170 |
|
| 171 |
+
model_file = self.output.trained_model_file
|
| 172 |
|
| 173 |
with Progress(
|
| 174 |
SpinnerColumn(),
|
| 175 |
TextColumn("[progress.description]{task.description}"),
|
| 176 |
BarColumn(),
|
| 177 |
TimeElapsedColumn(),
|
| 178 |
+
console=console) as progress:
|
| 179 |
+
training_task = progress.add_task("Training RL agent...", total=self.config.rl_training.episodes)
|
|
|
|
| 180 |
|
| 181 |
# Import training components
|
| 182 |
from rl.training import train_agent
|
| 183 |
from rl.simple_agent import TabularQAgent
|
| 184 |
import pickle
|
| 185 |
|
| 186 |
+
# Initialize agent with configured hyperparameters
|
| 187 |
+
rl_cfg = self.config.rl_training
|
| 188 |
agent = TabularQAgent(
|
| 189 |
+
learning_rate=rl_cfg.learning_rate,
|
| 190 |
+
epsilon=rl_cfg.initial_epsilon,
|
| 191 |
+
discount=rl_cfg.discount_factor
|
| 192 |
)
|
| 193 |
|
| 194 |
# Training with progress updates
|
| 195 |
# Note: train_agent handles its own progress internally
|
| 196 |
+
rl_cfg = self.config.rl_training
|
| 197 |
training_stats = train_agent(
|
| 198 |
agent=agent,
|
| 199 |
+
episodes=rl_cfg.episodes,
|
| 200 |
+
cases_per_episode=rl_cfg.cases_per_episode,
|
| 201 |
+
episode_length=rl_cfg.episode_length_days,
|
| 202 |
verbose=False # Disable internal printing
|
| 203 |
)
|
| 204 |
|
| 205 |
+
progress.update(training_task, completed=rl_cfg.episodes)
|
| 206 |
|
| 207 |
# Save trained agent
|
| 208 |
agent.save(model_file)
|
| 209 |
|
| 210 |
+
# Create symlink in models/ for backwards compatibility
|
| 211 |
+
self.output.create_model_symlink()
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
console.print(f" [green]OK[/green] Training complete -> {model_file}")
|
| 214 |
+
console.print(f" [green]OK[/green] Model symlink: models/latest.pkl")
|
| 215 |
console.print(f" [green]OK[/green] Final epsilon: {agent.epsilon:.4f}, States explored: {len(agent.q_table)}")
|
| 216 |
+
|
| 217 |
+
# Store model path for simulation step
|
| 218 |
+
self.trained_model_path = model_file
|
| 219 |
|
| 220 |
def _step_4_simulation(self):
|
| 221 |
"""Step 4: 2-Year Simulation"""
|
|
|
|
| 223 |
console.print(f" Duration: {self.config.sim_days} days ({self.config.sim_days/365:.1f} years)")
|
| 224 |
|
| 225 |
# Load cases
|
| 226 |
+
cases_file = self.output.training_cases_file
|
| 227 |
from scheduler.data.case_generator import CaseGenerator
|
| 228 |
cases = CaseGenerator.from_csv(cases_file)
|
| 229 |
|
|
|
|
| 235 |
for policy in self.config.policies:
|
| 236 |
console.print(f"\n Running {policy} policy simulation...")
|
| 237 |
|
| 238 |
+
policy_dir = self.output.get_policy_dir(policy)
|
| 239 |
policy_dir.mkdir(exist_ok=True)
|
| 240 |
|
| 241 |
+
# CRITICAL: Deep copy cases for each simulation to prevent state pollution
|
| 242 |
+
# Cases are mutated during simulation (status, hearing_count, disposal_date)
|
| 243 |
+
from copy import deepcopy
|
| 244 |
+
policy_cases = deepcopy(cases)
|
| 245 |
+
|
| 246 |
with Progress(
|
| 247 |
SpinnerColumn(),
|
| 248 |
TextColumn(f"[progress.description]Simulating {policy}..."),
|
| 249 |
BarColumn(),
|
| 250 |
+
console=console) as progress:
|
|
|
|
| 251 |
task = progress.add_task("Simulating...", total=100)
|
| 252 |
|
| 253 |
from scheduler.simulation.engine import CourtSim, CourtSimConfig
|
| 254 |
|
| 255 |
+
# Prepare config with RL model path if needed
|
| 256 |
+
cfg_kwargs = {
|
| 257 |
+
"start": sim_start,
|
| 258 |
+
"days": self.config.sim_days,
|
| 259 |
+
"seed": self.config.seed,
|
| 260 |
+
"policy": policy,
|
| 261 |
+
"duration_percentile": "median",
|
| 262 |
+
"log_dir": policy_dir,
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
# Add RL agent path for RL policy
|
| 266 |
+
if policy == "rl" and hasattr(self, 'trained_model_path'):
|
| 267 |
+
cfg_kwargs["rl_agent_path"] = self.trained_model_path
|
| 268 |
|
| 269 |
+
cfg = CourtSimConfig(**cfg_kwargs)
|
| 270 |
+
|
| 271 |
+
sim = CourtSim(cfg, policy_cases)
|
| 272 |
result = sim.run()
|
| 273 |
|
| 274 |
progress.update(task, completed=100)
|
| 275 |
|
| 276 |
results[policy] = {
|
| 277 |
'result': result,
|
| 278 |
+
'cases': policy_cases, # Use the deep-copied cases for this simulation
|
| 279 |
'sim': sim,
|
| 280 |
'dir': policy_dir
|
| 281 |
}
|
|
|
|
| 299 |
with Progress(
|
| 300 |
SpinnerColumn(),
|
| 301 |
TextColumn("[progress.description]{task.description}"),
|
| 302 |
+
console=console) as progress:
|
|
|
|
| 303 |
task = progress.add_task("Generating cause lists...", total=None)
|
| 304 |
|
| 305 |
from scheduler.output.cause_list import CauseListGenerator
|
|
|
|
| 323 |
with Progress(
|
| 324 |
SpinnerColumn(),
|
| 325 |
TextColumn("[progress.description]{task.description}"),
|
| 326 |
+
console=console) as progress:
|
|
|
|
| 327 |
task = progress.add_task("Analyzing results...", total=None)
|
| 328 |
|
| 329 |
# Generate comparison report
|
|
|
|
| 344 |
summary = self._generate_executive_summary()
|
| 345 |
|
| 346 |
# Save summary
|
| 347 |
+
summary_file = self.output.executive_summary_file
|
| 348 |
with open(summary_file, 'w') as f:
|
| 349 |
f.write(summary)
|
| 350 |
|
|
|
|
| 387 |
|
| 388 |
console.print(Panel.fit(
|
| 389 |
f"[bold green]Pipeline Complete![/bold green]\n\n"
|
| 390 |
+
f"Results: {self.output.run_dir}/\n"
|
| 391 |
f"Executive Summary: {summary_file}\n"
|
| 392 |
+
f"Visualizations: {self.output.visualizations_dir}/\n"
|
| 393 |
+
f"Cause Lists: {self.output.simulation_dir}/*/cause_lists/\n\n"
|
| 394 |
f"[yellow]Ready for hackathon submission![/yellow]",
|
| 395 |
box=box.DOUBLE_EDGE
|
| 396 |
))
|
| 397 |
|
| 398 |
def _generate_comparison_report(self):
|
| 399 |
"""Generate detailed comparison report"""
|
| 400 |
+
report_file = self.output.comparison_report_file
|
| 401 |
|
| 402 |
with open(report_file, 'w') as f:
|
| 403 |
f.write("# Court Scheduling System - Performance Comparison\n\n")
|
|
|
|
| 406 |
f.write("## Configuration\n\n")
|
| 407 |
f.write(f"- Training Cases: {self.config.n_cases:,}\n")
|
| 408 |
f.write(f"- Simulation Period: {self.config.sim_days} days ({self.config.sim_days/365:.1f} years)\n")
|
| 409 |
+
f.write(f"- RL Episodes: {self.config.rl_training.episodes}\n")
|
| 410 |
+
f.write(f"- RL Learning Rate: {self.config.rl_training.learning_rate}\n")
|
| 411 |
+
f.write(f"- RL Epsilon: {self.config.rl_training.initial_epsilon}\n")
|
| 412 |
f.write(f"- Policies Compared: {', '.join(self.config.policies)}\n\n")
|
| 413 |
|
| 414 |
f.write("## Results Summary\n\n")
|
|
|
|
| 425 |
|
| 426 |
def _generate_visualizations(self):
|
| 427 |
"""Generate performance visualizations"""
|
| 428 |
+
viz_dir = self.output.visualizations_dir
|
| 429 |
viz_dir.mkdir(exist_ok=True)
|
| 430 |
|
| 431 |
# This would generate charts comparing policies
|
|
|
|
| 461 |
**{disposal_rate:.1%} Case Disposal Rate** - Significantly improved case clearance
|
| 462 |
**{result.utilization:.1%} Court Utilization** - Optimal resource allocation
|
| 463 |
**{result.hearings_total:,} Hearings Scheduled** - Over {self.config.sim_days} days
|
| 464 |
+
**AI-Powered Decisions** - Reinforcement learning with {self.config.rl_training.episodes} training episodes
|
| 465 |
|
| 466 |
### Technical Innovation
|
| 467 |
|
|
|
|
| 512 |
|
| 513 |
# RL Training
|
| 514 |
console.print("\n[bold]RL Training[/bold]")
|
| 515 |
+
from rl.config import RLTrainingConfig
|
| 516 |
+
|
| 517 |
episodes = IntPrompt.ask("Training episodes", default=100)
|
| 518 |
learning_rate = FloatPrompt.ask("Learning rate", default=0.15)
|
| 519 |
|
| 520 |
+
rl_training_config = RLTrainingConfig(
|
| 521 |
+
episodes=episodes,
|
| 522 |
+
learning_rate=learning_rate)
|
| 523 |
+
|
| 524 |
# Simulation
|
| 525 |
console.print("\n[bold]Simulation[/bold]")
|
| 526 |
sim_days = IntPrompt.ask("Simulation days (730 = 2 years)", default=730)
|
|
|
|
| 539 |
n_cases=n_cases,
|
| 540 |
start_date=start_date,
|
| 541 |
end_date=end_date,
|
| 542 |
+
rl_training=rl_training_config,
|
|
|
|
| 543 |
sim_days=sim_days,
|
| 544 |
policies=policies,
|
|
|
|
| 545 |
generate_cause_lists=generate_cause_lists,
|
| 546 |
+
generate_visualizations=generate_visualizations)
|
|
|
|
| 547 |
|
| 548 |
@app.command()
|
| 549 |
def interactive():
|
|
|
|
| 554 |
console.print(f"\n[bold yellow]Configuration Summary:[/bold yellow]")
|
| 555 |
console.print(f" Cases: {config.n_cases:,}")
|
| 556 |
console.print(f" Period: {config.start_date} to {config.end_date}")
|
| 557 |
+
console.print(f" RL Episodes: {config.rl_training.episodes}")
|
| 558 |
+
console.print(f" RL Learning Rate: {config.rl_training.learning_rate}")
|
| 559 |
console.print(f" Simulation: {config.sim_days} days")
|
| 560 |
console.print(f" Policies: {', '.join(config.policies)}")
|
| 561 |
console.print(f" Output: {config.output_dir}")
|
|
|
|
| 584 |
"""Run quick demo with default parameters"""
|
| 585 |
console.print("[bold blue]Quick Demo Pipeline[/bold blue]\n")
|
| 586 |
|
| 587 |
+
from rl.config import QUICK_DEMO_RL_CONFIG
|
| 588 |
+
|
| 589 |
config = PipelineConfig(
|
| 590 |
n_cases=10000,
|
| 591 |
+
rl_training=QUICK_DEMO_RL_CONFIG,
|
| 592 |
+
sim_days=90)
|
|
|
|
|
|
|
| 593 |
|
| 594 |
pipeline = InteractivePipeline(config)
|
| 595 |
pipeline.run()
|
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration Architecture
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
The codebase uses a layered configuration approach separating concerns by domain and lifecycle.
|
| 5 |
+
|
| 6 |
+
## Configuration Layers
|
| 7 |
+
|
| 8 |
+
### 1. Domain Constants (`scheduler/data/config.py`)
|
| 9 |
+
**Purpose**: Immutable domain knowledge that never changes.
|
| 10 |
+
|
| 11 |
+
**Contains**:
|
| 12 |
+
- `STAGES` - Legal case lifecycle stages from domain knowledge
|
| 13 |
+
- `TERMINAL_STAGES` - Stages indicating case disposal
|
| 14 |
+
- `CASE_TYPES` - Valid case type taxonomy
|
| 15 |
+
- `CASE_TYPE_DISTRIBUTION` - Historical distribution from EDA
|
| 16 |
+
- `WORKING_DAYS_PER_YEAR` - Court calendar constant (192 days)
|
| 17 |
+
|
| 18 |
+
**When to use**: Values derived from legal/institutional domain that are facts, not tunable parameters.
|
| 19 |
+
|
| 20 |
+
### 2. RL Training Configuration (`rl/config.py`)
|
| 21 |
+
**Purpose**: Hyperparameters affecting RL agent learning behavior.
|
| 22 |
+
|
| 23 |
+
**Class**: `RLTrainingConfig`
|
| 24 |
+
|
| 25 |
+
**Parameters**:
|
| 26 |
+
- `episodes`: Number of training episodes
|
| 27 |
+
- `cases_per_episode`: Cases generated per episode
|
| 28 |
+
- `episode_length_days`: Simulation horizon per episode
|
| 29 |
+
- `learning_rate`: Q-learning alpha parameter
|
| 30 |
+
- `discount_factor`: Q-learning gamma parameter
|
| 31 |
+
- `initial_epsilon`: Starting exploration rate
|
| 32 |
+
- `epsilon_decay`: Exploration decay factor
|
| 33 |
+
- `min_epsilon`: Minimum exploration threshold
|
| 34 |
+
|
| 35 |
+
**Presets**:
|
| 36 |
+
- `DEFAULT_RL_TRAINING_CONFIG` - Standard training (100 episodes)
|
| 37 |
+
- `QUICK_DEMO_RL_CONFIG` - Fast testing (20 episodes)
|
| 38 |
+
|
| 39 |
+
**When to use**: Experimenting with RL training convergence and exploration strategies.
|
| 40 |
+
|
| 41 |
+
### 3. Policy Configuration (`rl/config.py`)
|
| 42 |
+
**Purpose**: Policy-specific filtering and prioritization behavior.
|
| 43 |
+
|
| 44 |
+
**Class**: `PolicyConfig`
|
| 45 |
+
|
| 46 |
+
**Parameters**:
|
| 47 |
+
- `min_gap_days`: Minimum days between hearings (fairness constraint)
|
| 48 |
+
- `max_gap_alert_days`: Maximum gap before triggering alerts
|
| 49 |
+
- `old_case_threshold_days`: Age threshold for priority boost
|
| 50 |
+
- `skip_unripe_cases`: Whether to filter unripe cases
|
| 51 |
+
- `allow_old_unripe_cases`: Allow scheduling very old unripe cases
|
| 52 |
+
|
| 53 |
+
**When to use**: Tuning policy filtering logic without changing core algorithm.
|
| 54 |
+
|
| 55 |
+
### 4. Simulation Configuration (`scheduler/simulation/engine.py`)
|
| 56 |
+
**Purpose**: Per-simulation operational parameters.
|
| 57 |
+
|
| 58 |
+
**Class**: `CourtSimConfig`
|
| 59 |
+
|
| 60 |
+
**Parameters**:
|
| 61 |
+
- `start`: Simulation start date
|
| 62 |
+
- `days`: Duration in days
|
| 63 |
+
- `seed`: Random seed for reproducibility
|
| 64 |
+
- `courtrooms`: Number of courtrooms to simulate
|
| 65 |
+
- `daily_capacity`: Cases per courtroom per day
|
| 66 |
+
- `policy`: Scheduling policy name (`fifo`, `age`, `readiness`, `rl`)
|
| 67 |
+
- `duration_percentile`: EDA percentile for stage durations
|
| 68 |
+
- `rl_agent_path`: Path to trained RL model (required if `policy="rl"`)
|
| 69 |
+
- `log_dir`: Output directory for metrics
|
| 70 |
+
|
| 71 |
+
**Validation**: `__post_init__` validates RL requirements and path types.
|
| 72 |
+
|
| 73 |
+
**When to use**: Each simulation run (different policies, time periods, or capacities).
|
| 74 |
+
|
| 75 |
+
### 5. Pipeline Configuration (`court_scheduler_rl.py`)
|
| 76 |
+
**Purpose**: Orchestrating multi-step workflow execution.
|
| 77 |
+
|
| 78 |
+
**Class**: `PipelineConfig`
|
| 79 |
+
|
| 80 |
+
**Parameters**:
|
| 81 |
+
- `n_cases`: Cases to generate for training
|
| 82 |
+
- `start_date`/`end_date`: Training data time window
|
| 83 |
+
- `rl_training`: RLTrainingConfig instance
|
| 84 |
+
- `sim_days`: Simulation duration
|
| 85 |
+
- `policies`: List of policies to compare
|
| 86 |
+
- Output location is managed automatically by `OutputManager` (the former `output_dir` field was removed in this refactor)
|
| 87 |
+
- `generate_cause_lists`/`generate_visualizations`: Output options
|
| 88 |
+
|
| 89 |
+
**When to use**: Running complete training→simulation→analysis workflows.
|
| 90 |
+
|
| 91 |
+
## Configuration Flow
|
| 92 |
+
|
| 93 |
+
```
|
| 94 |
+
Pipeline Execution:
|
| 95 |
+
├── PipelineConfig (workflow orchestration)
|
| 96 |
+
│ ├── RLTrainingConfig (training hyperparameters)
|
| 97 |
+
│ └── Data generation params
|
| 98 |
+
│
|
| 99 |
+
└── Per-Policy Simulation:
|
| 100 |
+
├── CourtSimConfig (simulation settings)
|
| 101 |
+
│ └── rl_agent_path (from training output)
|
| 102 |
+
│
|
| 103 |
+
└── Policy instantiation:
|
| 104 |
+
└── PolicyConfig (policy-specific settings)
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
## Design Principles
|
| 108 |
+
|
| 109 |
+
1. **Separation of Concerns**: Each config class owns one domain
|
| 110 |
+
2. **Type Safety**: Dataclasses with validation in `__post_init__`
|
| 111 |
+
3. **No Magic**: Explicit parameters, no hidden defaults
|
| 112 |
+
4. **Immutability**: Domain constants never change
|
| 113 |
+
5. **Composition**: Configs nest (PipelineConfig contains RLTrainingConfig)
|
| 114 |
+
|
| 115 |
+
## Examples
|
| 116 |
+
|
| 117 |
+
### Quick Demo
|
| 118 |
+
```python
|
| 119 |
+
from rl.config import QUICK_DEMO_RL_CONFIG
|
| 120 |
+
|
| 121 |
+
config = PipelineConfig(
|
| 122 |
+
n_cases=10000,
|
| 123 |
+
rl_training=QUICK_DEMO_RL_CONFIG, # 20 episodes
|
| 124 |
+
sim_days=90,
|
| 125 |
+
# output location is assigned automatically by OutputManager
|
| 126 |
+
)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Custom Training
|
| 130 |
+
```python
|
| 131 |
+
from rl.config import RLTrainingConfig
|
| 132 |
+
|
| 133 |
+
custom_rl = RLTrainingConfig(
|
| 134 |
+
episodes=500,
|
| 135 |
+
learning_rate=0.1,
|
| 136 |
+
initial_epsilon=0.3,
|
| 137 |
+
epsilon_decay=0.995
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
config = PipelineConfig(
|
| 141 |
+
n_cases=50000,
|
| 142 |
+
rl_training=custom_rl,
|
| 143 |
+
sim_days=730
|
| 144 |
+
)
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### Policy Tuning
|
| 148 |
+
```python
|
| 149 |
+
from rl.config import PolicyConfig
|
| 150 |
+
|
| 151 |
+
strict_policy = PolicyConfig(
|
| 152 |
+
min_gap_days=14, # More conservative
|
| 153 |
+
skip_unripe_cases=True,
|
| 154 |
+
allow_old_unripe_cases=False # Strict ripeness enforcement
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Pass to RLPolicy
|
| 158 |
+
policy = RLPolicy(agent_path=model_path, policy_config=strict_policy)
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
## Migration Guide
|
| 162 |
+
|
| 163 |
+
### Adding New Configuration
|
| 164 |
+
1. Determine layer (domain constant vs. tunable parameter)
|
| 165 |
+
2. Add to appropriate config class
|
| 166 |
+
3. Update `__post_init__` validation if needed
|
| 167 |
+
4. Document in this file
|
| 168 |
+
|
| 169 |
+
### Deprecating Parameters
|
| 170 |
+
1. Move to config class first (keep old path working)
|
| 171 |
+
2. Add deprecation warning
|
| 172 |
+
3. Remove old path after one release cycle
|
| 173 |
+
|
| 174 |
+
## Validation Rules
|
| 175 |
+
|
| 176 |
+
All config classes validate in `__post_init__`:
|
| 177 |
+
- Value ranges (0 < learning_rate ≤ 1)
|
| 178 |
+
- Type consistency (convert strings to Path)
|
| 179 |
+
- Cross-parameter constraints (max_gap ≥ min_gap)
|
| 180 |
+
- Required file existence (rl_agent_path must exist)
|
| 181 |
+
|
| 182 |
+
## Anti-Patterns
|
| 183 |
+
|
| 184 |
+
**DON'T**:
|
| 185 |
+
- ❌ Hardcode magic numbers in algorithms
|
| 186 |
+
- ❌ Use module-level mutable globals
|
| 187 |
+
- ❌ Mix domain constants with tunable parameters
|
| 188 |
+
- ❌ Create "god config" with everything in one class
|
| 189 |
+
|
| 190 |
+
**DO**:
|
| 191 |
+
- ✓ Separate by lifecycle and ownership
|
| 192 |
+
- ✓ Validate early (constructor time)
|
| 193 |
+
- ✓ Use dataclasses for immutability
|
| 194 |
+
- ✓ Provide sensible defaults with named presets
|
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Output Directory Refactoring - Implementation Status
|
| 2 |
+
|
| 3 |
+
## Completed
|
| 4 |
+
|
| 5 |
+
### 1. Created `OutputManager` class
|
| 6 |
+
- **File**: `scheduler/utils/output_manager.py`
|
| 7 |
+
- **Features**:
|
| 8 |
+
- Single run directory with timestamp-based ID
|
| 9 |
+
- Clean hierarchy: `eda/` `training/` `simulation/` `reports/`
|
| 10 |
+
- Property-based access to all output paths
|
| 11 |
+
- Config saved to run root for reproducibility
|
| 12 |
+
|
| 13 |
+
### 2. Integrated into Pipeline
|
| 14 |
+
- **File**: `court_scheduler_rl.py`
|
| 15 |
+
- **Changes**:
|
| 16 |
+
- `PipelineConfig` no longer has `output_dir` field
|
| 17 |
+
- `InteractivePipeline` uses `OutputManager` instance
|
| 18 |
+
- All `self.output_dir` references replaced with `self.output.{property}`
|
| 19 |
+
- Pipeline compiles successfully
|
| 20 |
+
|
| 21 |
+
## Completed Tasks
|
| 22 |
+
|
| 23 |
+
### 1. Remove Duplicate Model Saving (DONE)
|
| 24 |
+
- Removed duplicate model save in court_scheduler_rl.py
|
| 25 |
+
- Implemented `OutputManager.create_model_symlink()` method
|
| 26 |
+
- Model saved once to `outputs/runs/{run_id}/training/agent.pkl`
|
| 27 |
+
- Symlink created at `models/latest.pkl`
|
| 28 |
+
|
| 29 |
+
### 2. Update EDA Output Paths (DONE)
|
| 30 |
+
- Modified `src/eda_config.py` with:
|
| 31 |
+
- `set_output_paths()` function to configure from OutputManager
|
| 32 |
+
- Private getter functions (`_get_run_dir()`, `_get_params_dir()`, etc.)
|
| 33 |
+
- Fallback to legacy paths when running standalone
|
| 34 |
+
- Updated all EDA modules (eda_load_clean.py, eda_exploration.py, eda_parameters.py)
|
| 35 |
+
- Pipeline calls `set_output_paths()` before running EDA steps
|
| 36 |
+
- EDA outputs now write to `outputs/runs/{run_id}/eda/`
|
| 37 |
+
|
| 38 |
+
### 3. Fix Import Errors (DONE)
|
| 39 |
+
- Fixed syntax errors in EDA imports (removed parentheses from function names)
|
| 40 |
+
- All modules compile without errors
|
| 41 |
+
|
| 42 |
+
### 4. Test End-to-End (DONE)
|
| 43 |
+
```bash
|
| 44 |
+
uv run python court_scheduler_rl.py quick
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
**Status**: SUCCESS (Exit code: 0)
|
| 48 |
+
- All outputs in `outputs/runs/run_20251126_055943/`
|
| 49 |
+
- No scattered files
|
| 50 |
+
- Models symlinked correctly at `models/latest.pkl`
|
| 51 |
+
- Pipeline runs without errors
|
| 52 |
+
- Clean directory structure verified with `tree` command
|
| 53 |
+
|
| 54 |
+
## New Directory Structure
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
outputs/
|
| 58 |
+
└── runs/
|
| 59 |
+
└── run_20251126_123456/
|
| 60 |
+
├── config.json
|
| 61 |
+
├── eda/
|
| 62 |
+
│ ├── figures/
|
| 63 |
+
│ ├── params/
|
| 64 |
+
│ └── data/
|
| 65 |
+
├── training/
|
| 66 |
+
│ ├── cases.csv
|
| 67 |
+
│ ├── agent.pkl
|
| 68 |
+
│ └── stats.json
|
| 69 |
+
├── simulation/
|
| 70 |
+
│ ├── readiness/
|
| 71 |
+
│ └── rl/
|
| 72 |
+
└── reports/
|
| 73 |
+
├── EXECUTIVE_SUMMARY.md
|
| 74 |
+
├── COMPARISON_REPORT.md
|
| 75 |
+
└── visualizations/
|
| 76 |
+
|
| 77 |
+
models/
|
| 78 |
+
└── latest.pkl -> ../outputs/runs/run_20251126_123456/training/agent.pkl
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Benefits Achieved
|
| 82 |
+
|
| 83 |
+
1. **Single source of truth**: All run artifacts in one directory
|
| 84 |
+
2. **Reproducibility**: Config saved with outputs
|
| 85 |
+
3. **No duplication**: Files written once, not copied
|
| 86 |
+
4. **Clear hierarchy**: Logical organization by pipeline phase
|
| 87 |
+
5. **Easy cleanup**: Delete entire run directory
|
| 88 |
+
6. **Version control**: Run IDs sortable by timestamp
|
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RL training configuration and hyperparameters.
|
| 2 |
+
|
| 3 |
+
This module contains all configurable parameters for RL agent training,
|
| 4 |
+
separate from domain constants and simulation settings.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class RLTrainingConfig:
    """Configuration for RL agent training.

    Hyperparameters that affect learning behavior and convergence.
    All parameters are validated at construction time.
    """
    # Training episodes
    episodes: int = 100
    cases_per_episode: int = 1000
    episode_length_days: int = 60

    # Q-learning hyperparameters
    learning_rate: float = 0.15
    discount_factor: float = 0.95

    # Exploration strategy
    initial_epsilon: float = 0.4
    epsilon_decay: float = 0.99
    min_epsilon: float = 0.05

    # Training data generation
    training_seed: int = 42
    stage_mix_auto: bool = True  # Use EDA-derived stage distribution

    def __post_init__(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If any hyperparameter is outside its valid range.
        """
        if not (0.0 < self.learning_rate <= 1.0):
            raise ValueError(f"learning_rate must be in (0, 1], got {self.learning_rate}")

        if not (0.0 <= self.discount_factor <= 1.0):
            raise ValueError(f"discount_factor must be in [0, 1], got {self.discount_factor}")

        if not (0.0 <= self.initial_epsilon <= 1.0):
            raise ValueError(f"initial_epsilon must be in [0, 1], got {self.initial_epsilon}")

        # A decay outside (0, 1] would make epsilon grow (or zero out instantly).
        if not (0.0 < self.epsilon_decay <= 1.0):
            raise ValueError(f"epsilon_decay must be in (0, 1], got {self.epsilon_decay}")

        # The exploration floor must sit below the starting epsilon.
        if not (0.0 <= self.min_epsilon <= self.initial_epsilon):
            raise ValueError(
                f"min_epsilon ({self.min_epsilon}) must be in "
                f"[0, initial_epsilon ({self.initial_epsilon})]"
            )

        if self.episodes < 1:
            raise ValueError(f"episodes must be >= 1, got {self.episodes}")

        if self.cases_per_episode < 1:
            raise ValueError(f"cases_per_episode must be >= 1, got {self.cases_per_episode}")

        if self.episode_length_days < 1:
            raise ValueError(
                f"episode_length_days must be >= 1, got {self.episode_length_days}"
            )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
class PolicyConfig:
    """Configuration for scheduling policy behavior.

    Settings that affect how policies prioritize and filter cases.
    """
    # Minimum gap between hearings (days)
    min_gap_days: int = 7  # From MIN_GAP_BETWEEN_HEARINGS in config.py

    # Maximum gap before alert (days)
    max_gap_alert_days: int = 90  # From MAX_GAP_WITHOUT_ALERT

    # Old case threshold for priority boost (days)
    old_case_threshold_days: int = 180

    # Ripeness filtering
    skip_unripe_cases: bool = True
    allow_old_unripe_cases: bool = True  # Allow scheduling if age > old_case_threshold

    def __post_init__(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If gaps or thresholds are negative or inconsistent.
        """
        if self.min_gap_days < 0:
            raise ValueError(f"min_gap_days must be >= 0, got {self.min_gap_days}")

        if self.max_gap_alert_days < self.min_gap_days:
            raise ValueError(
                f"max_gap_alert_days ({self.max_gap_alert_days}) must be >= "
                f"min_gap_days ({self.min_gap_days})"
            )

        # A negative age threshold would silently mark every case as "old"
        # and defeat the ripeness filter.
        if self.old_case_threshold_days < 0:
            raise ValueError(
                f"old_case_threshold_days must be >= 0, got {self.old_case_threshold_days}"
            )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# Named preset configurations.
DEFAULT_RL_TRAINING_CONFIG = RLTrainingConfig()
DEFAULT_POLICY_CONFIG = PolicyConfig()

# Quick demo preset (for testing): few short episodes so the pipeline
# completes quickly end-to-end; other hyperparameters match the defaults.
QUICK_DEMO_RL_CONFIG = RLTrainingConfig(
    episodes=20,
    cases_per_episode=1000,
    episode_length_days=45,
    learning_rate=0.15,
    initial_epsilon=0.4,
)
|
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Centralized output directory management.
|
| 2 |
+
|
| 3 |
+
Provides clean, hierarchical output structure for all pipeline artifacts.
|
| 4 |
+
No scattered files, no duplicate saves, single source of truth per run.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Optional
|
| 10 |
+
import json
|
| 11 |
+
from dataclasses import asdict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class OutputManager:
    """Manages all output paths for a pipeline run.

    Design principles:
    - Single run directory contains ALL artifacts
    - No copying/moving files between directories
    - Clear hierarchy: eda/ training/ simulation/ reports/
    - Run ID is timestamp-based for sorting
    - Config saved at root for reproducibility
    """

    def __init__(self, run_id: Optional[str] = None, base_dir: Optional[Path] = None):
        """Initialize output manager for a pipeline run.

        Args:
            run_id: Unique run identifier (default: timestamp)
            base_dir: Base directory for all outputs (default: outputs/runs)
        """
        self.run_id = run_id or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Anchor the project root to this file's location
        # (scheduler/utils/output_manager.py -> repo root). Storing it here
        # keeps create_model_symlink() correct even when a custom base_dir
        # is supplied (the old code derived it from run_dir's parents).
        self._project_root = Path(__file__).parent.parent.parent
        self.base_dir = base_dir or (self._project_root / "outputs" / "runs")
        self.run_dir = self.base_dir / self.run_id

        # Primary output directories
        self.eda_dir = self.run_dir / "eda"
        self.training_dir = self.run_dir / "training"
        self.simulation_dir = self.run_dir / "simulation"
        self.reports_dir = self.run_dir / "reports"

        # EDA subdirectories
        self.eda_figures = self.eda_dir / "figures"
        self.eda_params = self.eda_dir / "params"
        self.eda_data = self.eda_dir / "data"

        # Reports subdirectories
        self.visualizations_dir = self.reports_dir / "visualizations"

    def create_structure(self) -> None:
        """Create all output directories (idempotent)."""
        for dir_path in (
            self.run_dir,
            self.eda_dir,
            self.eda_figures,
            self.eda_params,
            self.eda_data,
            self.training_dir,
            self.simulation_dir,
            self.reports_dir,
            self.visualizations_dir,
        ):
            dir_path.mkdir(parents=True, exist_ok=True)

    def save_config(self, config) -> None:
        """Save pipeline configuration to run directory.

        Args:
            config: PipelineConfig, any dataclass, or a plain JSON-serializable
                mapping. Nested dataclasses (like rl_training) are flattened
                by asdict(); non-JSON values fall back to str().
        """
        config_path = self.run_dir / "config.json"
        config_dict = asdict(config) if hasattr(config, '__dataclass_fields__') else config
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=2, default=str)

    def get_policy_dir(self, policy_name: str) -> Path:
        """Get simulation directory for a specific policy (created on demand).

        Args:
            policy_name: Policy name (e.g., 'readiness', 'rl')

        Returns:
            Path to policy simulation directory
        """
        policy_dir = self.simulation_dir / policy_name
        policy_dir.mkdir(parents=True, exist_ok=True)
        return policy_dir

    def get_cause_list_dir(self, policy_name: str) -> Path:
        """Get cause list directory for a policy (created on demand).

        Args:
            policy_name: Policy name

        Returns:
            Path to cause list directory
        """
        cause_list_dir = self.get_policy_dir(policy_name) / "cause_lists"
        cause_list_dir.mkdir(parents=True, exist_ok=True)
        return cause_list_dir

    @property
    def training_cases_file(self) -> Path:
        """Path to generated training cases CSV."""
        return self.training_dir / "cases.csv"

    @property
    def trained_model_file(self) -> Path:
        """Path to trained RL agent model."""
        return self.training_dir / "agent.pkl"

    @property
    def training_stats_file(self) -> Path:
        """Path to training statistics JSON."""
        return self.training_dir / "stats.json"

    @property
    def executive_summary_file(self) -> Path:
        """Path to executive summary markdown."""
        return self.reports_dir / "EXECUTIVE_SUMMARY.md"

    @property
    def comparison_report_file(self) -> Path:
        """Path to comparison report markdown."""
        return self.reports_dir / "COMPARISON_REPORT.md"

    def create_model_symlink(self, alias: str = "latest") -> None:
        """Create symlink in models/ directory pointing to trained model.

        Args:
            alias: Symlink name (e.g., 'latest', 'v1.0')
        """
        # BUGFIX: use the project root captured in __init__ instead of
        # walking up from run_dir, which pointed at the wrong directory
        # whenever a custom base_dir was used.
        models_dir = self._project_root / "models"
        models_dir.mkdir(exist_ok=True)

        symlink_path = models_dir / f"{alias}.pkl"
        target = self.trained_model_file

        # Remove any stale link (is_symlink() also catches broken links
        # for which exists() is False).
        if symlink_path.exists() or symlink_path.is_symlink():
            symlink_path.unlink()

        # Create symlink (use absolute path for cross-directory links)
        try:
            symlink_path.symlink_to(target.resolve())
        except (OSError, NotImplementedError):
            # Fallback: copy file if symlinks not supported (Windows without dev mode)
            import shutil
            shutil.copy2(target, symlink_path)

    def __str__(self) -> str:
        return f"OutputManager(run_id='{self.run_id}', run_dir='{self.run_dir}')"

    def __repr__(self) -> str:
        return self.__str__()
|
|
@@ -8,27 +8,80 @@ from pathlib import Path
|
|
| 8 |
# -------------------------------------------------------------------
|
| 9 |
# Paths and versioning
|
| 10 |
# -------------------------------------------------------------------
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
DUCKDB_FILE = DATA_DIR / "court_data.duckdb"
|
| 13 |
CASES_FILE = DATA_DIR / "ISDMHack_Cases_WPfinal.csv"
|
| 14 |
HEAR_FILE = DATA_DIR / "ISDMHack_Hear.csv"
|
| 15 |
|
| 16 |
-
|
|
|
|
| 17 |
FIGURES_DIR = REPORTS_DIR / "figures"
|
| 18 |
-
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
| 19 |
|
| 20 |
VERSION = "v0.4.0"
|
| 21 |
RUN_TS = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 22 |
|
| 23 |
-
|
| 24 |
-
RUN_DIR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
HEARINGS_CLEAN_PARQUET
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# -------------------------------------------------------------------
|
| 34 |
# Null tokens and canonicalisation
|
|
@@ -37,21 +90,31 @@ NULL_TOKENS = ["", "NULL", "Null", "null", "NA", "N/A", "na", "NaN", "nan", "-",
|
|
| 37 |
|
| 38 |
|
| 39 |
def copy_to_versioned(filename: str) -> None:
|
| 40 |
-
"""
|
| 41 |
-
|
| 42 |
-
dst = RUN_DIR / filename
|
| 43 |
-
try:
|
| 44 |
-
if src.exists():
|
| 45 |
-
shutil.copyfile(src, dst)
|
| 46 |
-
except Exception as e:
|
| 47 |
-
print(f"[WARN] Versioned copy failed for {filename}: {e}")
|
| 48 |
|
| 49 |
|
| 50 |
def write_metadata(meta: dict) -> None:
|
| 51 |
"""Write run metadata into RUN_DIR/metadata.json."""
|
| 52 |
-
|
|
|
|
| 53 |
try:
|
| 54 |
with open(meta_path, "w", encoding="utf-8") as f:
|
| 55 |
json.dump(meta, f, indent=2, default=str)
|
| 56 |
except Exception as e:
|
| 57 |
print(f"[WARN] Metadata export error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# -------------------------------------------------------------------
|
| 9 |
# Paths and versioning
|
| 10 |
# -------------------------------------------------------------------
|
| 11 |
+
# Project root (repo root) = parent of src/
|
| 12 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 13 |
+
|
| 14 |
+
DATA_DIR = PROJECT_ROOT / "Data"
|
| 15 |
DUCKDB_FILE = DATA_DIR / "court_data.duckdb"
|
| 16 |
CASES_FILE = DATA_DIR / "ISDMHack_Cases_WPfinal.csv"
|
| 17 |
HEAR_FILE = DATA_DIR / "ISDMHack_Hear.csv"
|
| 18 |
|
| 19 |
+
# Default paths (used when EDA is run standalone)
|
| 20 |
+
REPORTS_DIR = PROJECT_ROOT / "reports"
|
| 21 |
FIGURES_DIR = REPORTS_DIR / "figures"
|
|
|
|
| 22 |
|
| 23 |
VERSION = "v0.4.0"
|
| 24 |
RUN_TS = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 25 |
|
| 26 |
+
# These will be set by set_output_paths() when running from pipeline
|
| 27 |
+
RUN_DIR = None
|
| 28 |
+
PARAMS_DIR = None
|
| 29 |
+
CASES_CLEAN_PARQUET = None
|
| 30 |
+
HEARINGS_CLEAN_PARQUET = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def set_output_paths(eda_dir: Path, data_dir: Path, params_dir: Path) -> None:
    """Configure output paths from OutputManager.

    Call this from pipeline before running EDA modules.
    When not called, falls back to legacy reports/figures/ structure.

    Args:
        eda_dir: Directory that receives figures and metadata for this run.
        data_dir: Directory that receives the cleaned parquet files.
        params_dir: Directory that receives derived parameter files.
    """
    global RUN_DIR, PARAMS_DIR, CASES_CLEAN_PARQUET, HEARINGS_CLEAN_PARQUET

    RUN_DIR = eda_dir
    PARAMS_DIR = params_dir
    CASES_CLEAN_PARQUET = data_dir / "cases_clean.parquet"
    HEARINGS_CLEAN_PARQUET = data_dir / "hearings_clean.parquet"

    # Make sure both target directories exist before anything writes to them.
    for target in (RUN_DIR, PARAMS_DIR):
        target.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _get_run_dir() -> Path:
    """Return RUN_DIR, lazily creating the legacy default when unset."""
    global RUN_DIR
    if RUN_DIR is not None:
        return RUN_DIR
    # Standalone mode: fall back to the legacy versioned directory layout.
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    RUN_DIR = FIGURES_DIR / f"{VERSION}_{RUN_TS}"
    RUN_DIR.mkdir(parents=True, exist_ok=True)
    return RUN_DIR
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _get_params_dir() -> Path:
    """Return PARAMS_DIR, lazily defaulting to <run_dir>/params."""
    global PARAMS_DIR
    if PARAMS_DIR is not None:
        return PARAMS_DIR
    PARAMS_DIR = _get_run_dir() / "params"
    PARAMS_DIR.mkdir(parents=True, exist_ok=True)
    return PARAMS_DIR
|
| 69 |
+
|
| 70 |
|
| 71 |
+
def _get_cases_parquet() -> Path:
    """Return the cleaned-cases parquet path, defaulting into the run dir."""
    global CASES_CLEAN_PARQUET
    if CASES_CLEAN_PARQUET is not None:
        return CASES_CLEAN_PARQUET
    CASES_CLEAN_PARQUET = _get_run_dir() / "cases_clean.parquet"
    return CASES_CLEAN_PARQUET
|
| 77 |
|
| 78 |
+
|
| 79 |
+
def _get_hearings_parquet() -> Path:
    """Return the cleaned-hearings parquet path, defaulting into the run dir."""
    global HEARINGS_CLEAN_PARQUET
    if HEARINGS_CLEAN_PARQUET is not None:
        return HEARINGS_CLEAN_PARQUET
    HEARINGS_CLEAN_PARQUET = _get_run_dir() / "hearings_clean.parquet"
    return HEARINGS_CLEAN_PARQUET
|
| 85 |
|
| 86 |
# -------------------------------------------------------------------
|
| 87 |
# Null tokens and canonicalisation
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def copy_to_versioned(filename: str) -> None:
    """Deprecated no-op kept for backward compatibility.

    OutputManager now writes each artifact directly to its final location,
    so versioned copies are no longer needed. Emits a DeprecationWarning so
    remaining call sites get migrated instead of silently doing nothing.

    Args:
        filename: Ignored; retained so existing call sites keep working.
    """
    import warnings

    warnings.warn(
        "copy_to_versioned() is deprecated and does nothing; "
        "outputs are managed by OutputManager.",
        DeprecationWarning,
        stacklevel=2,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
def write_metadata(meta: dict) -> None:
    """Write run metadata into RUN_DIR/metadata.json."""
    meta_path = _get_run_dir() / "metadata.json"
    try:
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, default=str)
    except Exception as e:
        # Best-effort: a metadata export failure should not abort the EDA run.
        print(f"[WARN] Metadata export error: {e}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def safe_write_figure(fig, filename: str) -> None:
    """Write a plotly figure as HTML into the current EDA run directory.

    Args:
        fig: Plotly figure object
        filename: HTML filename (e.g., "1_case_type_distribution.html")

    Raises:
        RuntimeError: If the figure cannot be written; the original
            exception is chained for debugging.
    """
    run_dir = _get_run_dir()
    output_path = run_dir / filename
    try:
        fig.write_html(str(output_path))
    except Exception as e:
        # Include the filename so failures are traceable from logs.
        raise RuntimeError(f"Failed to write {filename} to {output_path}: {e}") from e
|
|
@@ -13,7 +13,7 @@ Inputs:
|
|
| 13 |
- Cleaned Parquet from eda_load_clean.
|
| 14 |
|
| 15 |
Outputs:
|
| 16 |
-
- Interactive HTML plots in FIGURES_DIR and versioned copies in
|
| 17 |
- Some CSV summaries (e.g., stage_duration.csv, transitions.csv, monthly_anomalies.csv).
|
| 18 |
"""
|
| 19 |
|
|
@@ -25,19 +25,19 @@ import plotly.graph_objects as go
|
|
| 25 |
import plotly.io as pio
|
| 26 |
import polars as pl
|
| 27 |
from src.eda_config import (
|
| 28 |
-
|
| 29 |
FIGURES_DIR,
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
)
|
| 34 |
|
| 35 |
pio.renderers.default = "browser"
|
| 36 |
|
| 37 |
|
| 38 |
def load_cleaned():
|
| 39 |
-
cases = pl.read_parquet(
|
| 40 |
-
hearings = pl.read_parquet(
|
| 41 |
print("Loaded cleaned data for exploration")
|
| 42 |
print("Cases:", cases.shape, "Hearings:", hearings.shape)
|
| 43 |
return cases, hearings
|
|
@@ -58,9 +58,7 @@ def run_exploration() -> None:
|
|
| 58 |
title="Case Type Distribution",
|
| 59 |
)
|
| 60 |
fig1.update_layout(showlegend=False, xaxis_title="Case Type", yaxis_title="Number of Cases")
|
| 61 |
-
|
| 62 |
-
fig1.write_html(str(FIGURES_DIR / f1))
|
| 63 |
-
copy_to_versioned(f1)
|
| 64 |
|
| 65 |
# --------------------------------------------------
|
| 66 |
# 2. Filing Trends by Year
|
|
@@ -73,8 +71,7 @@ def run_exploration() -> None:
|
|
| 73 |
fig2.update_traces(line_color="royalblue")
|
| 74 |
fig2.update_layout(xaxis=dict(rangeslider=dict(visible=True)))
|
| 75 |
f2 = "2_cases_filed_by_year.html"
|
| 76 |
-
fig2
|
| 77 |
-
copy_to_versioned(f2)
|
| 78 |
|
| 79 |
# --------------------------------------------------
|
| 80 |
# 3. Disposal Duration Distribution
|
|
@@ -89,8 +86,7 @@ def run_exploration() -> None:
|
|
| 89 |
)
|
| 90 |
fig3.update_layout(xaxis_title="Days", yaxis_title="Cases")
|
| 91 |
f3 = "3_disposal_time_distribution.html"
|
| 92 |
-
fig3
|
| 93 |
-
copy_to_versioned(f3)
|
| 94 |
|
| 95 |
# --------------------------------------------------
|
| 96 |
# 4. Hearings vs Disposal Time
|
|
@@ -106,8 +102,7 @@ def run_exploration() -> None:
|
|
| 106 |
)
|
| 107 |
fig4.update_traces(marker=dict(size=6, opacity=0.7))
|
| 108 |
f4 = "4_hearings_vs_disposal.html"
|
| 109 |
-
fig4
|
| 110 |
-
copy_to_versioned(f4)
|
| 111 |
|
| 112 |
# --------------------------------------------------
|
| 113 |
# 5. Boxplot by Case Type
|
|
@@ -121,8 +116,7 @@ def run_exploration() -> None:
|
|
| 121 |
)
|
| 122 |
fig5.update_layout(showlegend=False)
|
| 123 |
f5 = "5_box_disposal_by_type.html"
|
| 124 |
-
fig5
|
| 125 |
-
copy_to_versioned(f5)
|
| 126 |
|
| 127 |
# --------------------------------------------------
|
| 128 |
# 6. Stage Frequency
|
|
@@ -139,8 +133,7 @@ def run_exploration() -> None:
|
|
| 139 |
)
|
| 140 |
fig6.update_layout(showlegend=False, xaxis_title="Stage", yaxis_title="Count")
|
| 141 |
f6 = "6_stage_frequency.html"
|
| 142 |
-
fig6
|
| 143 |
-
copy_to_versioned(f6)
|
| 144 |
|
| 145 |
# --------------------------------------------------
|
| 146 |
# 7. Gap median by case type
|
|
@@ -154,8 +147,7 @@ def run_exploration() -> None:
|
|
| 154 |
title="Median Hearing Gap by Case Type",
|
| 155 |
)
|
| 156 |
fg = "9_gap_median_by_type.html"
|
| 157 |
-
fig_gap
|
| 158 |
-
copy_to_versioned(fg)
|
| 159 |
|
| 160 |
# --------------------------------------------------
|
| 161 |
# 8. Stage transitions & bottleneck plot
|
|
@@ -219,7 +211,7 @@ def run_exploration() -> None:
|
|
| 219 |
<= pl.col("STAGE_TO").map_elements(lambda s: order_idx.get(s, 10))
|
| 220 |
).sort("N", descending=True)
|
| 221 |
|
| 222 |
-
transitions.write_csv(
|
| 223 |
|
| 224 |
runs = (
|
| 225 |
h_stage.with_columns(
|
|
@@ -256,7 +248,7 @@ def run_exploration() -> None:
|
|
| 256 |
)
|
| 257 |
.sort("RUN_MEDIAN_DAYS", descending=True)
|
| 258 |
)
|
| 259 |
-
stage_duration.write_csv(
|
| 260 |
|
| 261 |
# Sankey
|
| 262 |
try:
|
|
@@ -284,8 +276,7 @@ def run_exploration() -> None:
|
|
| 284 |
)
|
| 285 |
sankey.update_layout(title_text="Stage Transition Sankey (Ordered)")
|
| 286 |
f10 = "10_stage_transition_sankey.html"
|
| 287 |
-
sankey
|
| 288 |
-
copy_to_versioned(f10)
|
| 289 |
except Exception as e:
|
| 290 |
print("Sankey error:", e)
|
| 291 |
|
|
@@ -301,8 +292,7 @@ def run_exploration() -> None:
|
|
| 301 |
title="Stage Bottleneck Impact (Median Days x Runs)",
|
| 302 |
)
|
| 303 |
fb = "15_bottleneck_impact.html"
|
| 304 |
-
fig_b
|
| 305 |
-
copy_to_versioned(fb)
|
| 306 |
except Exception as e:
|
| 307 |
print("Bottleneck plot error:", e)
|
| 308 |
|
|
@@ -321,7 +311,7 @@ def run_exploration() -> None:
|
|
| 321 |
.with_columns(pl.date(pl.col("Y"), pl.col("M"), pl.lit(1)).alias("YM"))
|
| 322 |
)
|
| 323 |
monthly_listings = m_hear.group_by("YM").agg(pl.len().alias("N_HEARINGS")).sort("YM")
|
| 324 |
-
monthly_listings.write_csv(
|
| 325 |
|
| 326 |
try:
|
| 327 |
fig_m = px.line(
|
|
@@ -332,8 +322,7 @@ def run_exploration() -> None:
|
|
| 332 |
)
|
| 333 |
fig_m.update_layout(yaxis=dict(tickformat=",d"))
|
| 334 |
fm = "11_monthly_hearings.html"
|
| 335 |
-
fig_m
|
| 336 |
-
copy_to_versioned(fm)
|
| 337 |
except Exception as e:
|
| 338 |
print("Monthly listings error:", e)
|
| 339 |
|
|
@@ -380,12 +369,11 @@ def run_exploration() -> None:
|
|
| 380 |
yaxis=dict(tickformat=",d"),
|
| 381 |
)
|
| 382 |
fw = "11b_monthly_waterfall.html"
|
| 383 |
-
fig_w
|
| 384 |
-
copy_to_versioned(fw)
|
| 385 |
|
| 386 |
ml_pd_out = ml_pd.copy()
|
| 387 |
ml_pd_out["YM"] = ml_pd_out["YM"].astype(str)
|
| 388 |
-
ml_pd_out.to_csv(
|
| 389 |
except Exception as e:
|
| 390 |
print("Monthly waterfall error:", e)
|
| 391 |
|
|
@@ -420,8 +408,7 @@ def run_exploration() -> None:
|
|
| 420 |
xaxis={"categoryorder": "total descending"}, yaxis=dict(tickformat=",d")
|
| 421 |
)
|
| 422 |
fj = "12_judge_day_load.html"
|
| 423 |
-
fig_j
|
| 424 |
-
copy_to_versioned(fj)
|
| 425 |
except Exception as e:
|
| 426 |
print("Judge workload error:", e)
|
| 427 |
|
|
@@ -447,8 +434,7 @@ def run_exploration() -> None:
|
|
| 447 |
xaxis={"categoryorder": "total descending"}, yaxis=dict(tickformat=",d")
|
| 448 |
)
|
| 449 |
fc = "12b_court_day_load.html"
|
| 450 |
-
fig_court
|
| 451 |
-
copy_to_versioned(fc)
|
| 452 |
except Exception as e:
|
| 453 |
print("Court workload error:", e)
|
| 454 |
|
|
@@ -488,7 +474,7 @@ def run_exploration() -> None:
|
|
| 488 |
.with_columns((pl.col("N") / pl.col("N").sum().over("CASE_TYPE")).alias("SHARE"))
|
| 489 |
.sort(["CASE_TYPE", "SHARE"], descending=[False, True])
|
| 490 |
)
|
| 491 |
-
tag_share.write_csv(
|
| 492 |
try:
|
| 493 |
fig_t = px.bar(
|
| 494 |
tag_share.to_pandas(),
|
|
@@ -499,8 +485,7 @@ def run_exploration() -> None:
|
|
| 499 |
barmode="stack",
|
| 500 |
)
|
| 501 |
ft = "14_purpose_tag_shares.html"
|
| 502 |
-
fig_t
|
| 503 |
-
copy_to_versioned(ft)
|
| 504 |
except Exception as e:
|
| 505 |
print("Purpose shares error:", e)
|
| 506 |
|
|
|
|
| 13 |
- Cleaned Parquet from eda_load_clean.
|
| 14 |
|
| 15 |
Outputs:
|
| 16 |
+
- Interactive HTML plots in FIGURES_DIR and versioned copies in _get_run_dir().
|
| 17 |
- Some CSV summaries (e.g., stage_duration.csv, transitions.csv, monthly_anomalies.csv).
|
| 18 |
"""
|
| 19 |
|
|
|
|
| 25 |
import plotly.io as pio
|
| 26 |
import polars as pl
|
| 27 |
from src.eda_config import (
|
| 28 |
+
_get_cases_parquet,
|
| 29 |
FIGURES_DIR,
|
| 30 |
+
_get_hearings_parquet,
|
| 31 |
+
_get_run_dir,
|
| 32 |
+
safe_write_figure,
|
| 33 |
)
|
| 34 |
|
| 35 |
pio.renderers.default = "browser"
|
| 36 |
|
| 37 |
|
| 38 |
def load_cleaned():
|
| 39 |
+
cases = pl.read_parquet(_get_cases_parquet())
|
| 40 |
+
hearings = pl.read_parquet(_get_hearings_parquet())
|
| 41 |
print("Loaded cleaned data for exploration")
|
| 42 |
print("Cases:", cases.shape, "Hearings:", hearings.shape)
|
| 43 |
return cases, hearings
|
|
|
|
| 58 |
title="Case Type Distribution",
|
| 59 |
)
|
| 60 |
fig1.update_layout(showlegend=False, xaxis_title="Case Type", yaxis_title="Number of Cases")
|
| 61 |
+
safe_write_figure(fig1, "1_case_type_distribution.html")
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# --------------------------------------------------
|
| 64 |
# 2. Filing Trends by Year
|
|
|
|
| 71 |
fig2.update_traces(line_color="royalblue")
|
| 72 |
fig2.update_layout(xaxis=dict(rangeslider=dict(visible=True)))
|
| 73 |
f2 = "2_cases_filed_by_year.html"
|
| 74 |
+
safe_write_figure(fig2, f2)
|
|
|
|
| 75 |
|
| 76 |
# --------------------------------------------------
|
| 77 |
# 3. Disposal Duration Distribution
|
|
|
|
| 86 |
)
|
| 87 |
fig3.update_layout(xaxis_title="Days", yaxis_title="Cases")
|
| 88 |
f3 = "3_disposal_time_distribution.html"
|
| 89 |
+
safe_write_figure(fig3, f3)
|
|
|
|
| 90 |
|
| 91 |
# --------------------------------------------------
|
| 92 |
# 4. Hearings vs Disposal Time
|
|
|
|
| 102 |
)
|
| 103 |
fig4.update_traces(marker=dict(size=6, opacity=0.7))
|
| 104 |
f4 = "4_hearings_vs_disposal.html"
|
| 105 |
+
safe_write_figure(fig4, f4)
|
|
|
|
| 106 |
|
| 107 |
# --------------------------------------------------
|
| 108 |
# 5. Boxplot by Case Type
|
|
|
|
| 116 |
)
|
| 117 |
fig5.update_layout(showlegend=False)
|
| 118 |
f5 = "5_box_disposal_by_type.html"
|
| 119 |
+
safe_write_figure(fig5, f5)
|
|
|
|
| 120 |
|
| 121 |
# --------------------------------------------------
|
| 122 |
# 6. Stage Frequency
|
|
|
|
| 133 |
)
|
| 134 |
fig6.update_layout(showlegend=False, xaxis_title="Stage", yaxis_title="Count")
|
| 135 |
f6 = "6_stage_frequency.html"
|
| 136 |
+
safe_write_figure(fig6, f6)
|
|
|
|
| 137 |
|
| 138 |
# --------------------------------------------------
|
| 139 |
# 7. Gap median by case type
|
|
|
|
| 147 |
title="Median Hearing Gap by Case Type",
|
| 148 |
)
|
| 149 |
fg = "9_gap_median_by_type.html"
|
| 150 |
+
safe_write_figure(fig_gap, fg)
|
|
|
|
| 151 |
|
| 152 |
# --------------------------------------------------
|
| 153 |
# 8. Stage transitions & bottleneck plot
|
|
|
|
| 211 |
<= pl.col("STAGE_TO").map_elements(lambda s: order_idx.get(s, 10))
|
| 212 |
).sort("N", descending=True)
|
| 213 |
|
| 214 |
+
transitions.write_csv(str(_get_run_dir() / "transitions.csv"))
|
| 215 |
|
| 216 |
runs = (
|
| 217 |
h_stage.with_columns(
|
|
|
|
| 248 |
)
|
| 249 |
.sort("RUN_MEDIAN_DAYS", descending=True)
|
| 250 |
)
|
| 251 |
+
stage_duration.write_csv(str(_get_run_dir() / "stage_duration.csv"))
|
| 252 |
|
| 253 |
# Sankey
|
| 254 |
try:
|
|
|
|
| 276 |
)
|
| 277 |
sankey.update_layout(title_text="Stage Transition Sankey (Ordered)")
|
| 278 |
f10 = "10_stage_transition_sankey.html"
|
| 279 |
+
safe_write_figure(sankey, f10)
|
|
|
|
| 280 |
except Exception as e:
|
| 281 |
print("Sankey error:", e)
|
| 282 |
|
|
|
|
| 292 |
title="Stage Bottleneck Impact (Median Days x Runs)",
|
| 293 |
)
|
| 294 |
fb = "15_bottleneck_impact.html"
|
| 295 |
+
safe_write_figure(fig_b, fb)
|
|
|
|
| 296 |
except Exception as e:
|
| 297 |
print("Bottleneck plot error:", e)
|
| 298 |
|
|
|
|
| 311 |
.with_columns(pl.date(pl.col("Y"), pl.col("M"), pl.lit(1)).alias("YM"))
|
| 312 |
)
|
| 313 |
monthly_listings = m_hear.group_by("YM").agg(pl.len().alias("N_HEARINGS")).sort("YM")
|
| 314 |
+
monthly_listings.write_csv(str(_get_run_dir() / "monthly_hearings.csv"))
|
| 315 |
|
| 316 |
try:
|
| 317 |
fig_m = px.line(
|
|
|
|
| 322 |
)
|
| 323 |
fig_m.update_layout(yaxis=dict(tickformat=",d"))
|
| 324 |
fm = "11_monthly_hearings.html"
|
| 325 |
+
safe_write_figure(fig_m, fm)
|
|
|
|
| 326 |
except Exception as e:
|
| 327 |
print("Monthly listings error:", e)
|
| 328 |
|
|
|
|
| 369 |
yaxis=dict(tickformat=",d"),
|
| 370 |
)
|
| 371 |
fw = "11b_monthly_waterfall.html"
|
| 372 |
+
safe_write_figure(fig_w, fw)
|
|
|
|
| 373 |
|
| 374 |
ml_pd_out = ml_pd.copy()
|
| 375 |
ml_pd_out["YM"] = ml_pd_out["YM"].astype(str)
|
| 376 |
+
ml_pd_out.to_csv(str(_get_run_dir() / "monthly_anomalies.csv"), index=False)
|
| 377 |
except Exception as e:
|
| 378 |
print("Monthly waterfall error:", e)
|
| 379 |
|
|
|
|
| 408 |
xaxis={"categoryorder": "total descending"}, yaxis=dict(tickformat=",d")
|
| 409 |
)
|
| 410 |
fj = "12_judge_day_load.html"
|
| 411 |
+
safe_write_figure(fig_j, fj)
|
|
|
|
| 412 |
except Exception as e:
|
| 413 |
print("Judge workload error:", e)
|
| 414 |
|
|
|
|
| 434 |
xaxis={"categoryorder": "total descending"}, yaxis=dict(tickformat=",d")
|
| 435 |
)
|
| 436 |
fc = "12b_court_day_load.html"
|
| 437 |
+
safe_write_figure(fig_court, fc)
|
|
|
|
| 438 |
except Exception as e:
|
| 439 |
print("Court workload error:", e)
|
| 440 |
|
|
|
|
| 474 |
.with_columns((pl.col("N") / pl.col("N").sum().over("CASE_TYPE")).alias("SHARE"))
|
| 475 |
.sort(["CASE_TYPE", "SHARE"], descending=[False, True])
|
| 476 |
)
|
| 477 |
+
tag_share.write_csv(str(_get_run_dir() / "purpose_tag_shares.csv"))
|
| 478 |
try:
|
| 479 |
fig_t = px.bar(
|
| 480 |
tag_share.to_pandas(),
|
|
|
|
| 485 |
barmode="stack",
|
| 486 |
)
|
| 487 |
ft = "14_purpose_tag_shares.html"
|
| 488 |
+
safe_write_figure(fig_t, ft)
|
|
|
|
| 489 |
except Exception as e:
|
| 490 |
print("Purpose shares error:", e)
|
| 491 |
|
|
@@ -13,9 +13,9 @@ from datetime import timedelta
|
|
| 13 |
import polars as pl
|
| 14 |
import duckdb
|
| 15 |
from src.eda_config import (
|
| 16 |
-
|
| 17 |
DUCKDB_FILE,
|
| 18 |
-
|
| 19 |
NULL_TOKENS,
|
| 20 |
RUN_TS,
|
| 21 |
VERSION,
|
|
@@ -224,10 +224,10 @@ def clean_and_augment(
|
|
| 224 |
|
| 225 |
|
| 226 |
def save_clean(cases: pl.DataFrame, hearings: pl.DataFrame) -> None:
|
| 227 |
-
cases.write_parquet(
|
| 228 |
-
hearings.write_parquet(
|
| 229 |
-
print(f"Saved cleaned cases -> {
|
| 230 |
-
print(f"Saved cleaned hearings -> {
|
| 231 |
|
| 232 |
meta = {
|
| 233 |
"version": VERSION,
|
|
|
|
| 13 |
import polars as pl
|
| 14 |
import duckdb
|
| 15 |
from src.eda_config import (
|
| 16 |
+
_get_cases_parquet,
|
| 17 |
DUCKDB_FILE,
|
| 18 |
+
_get_hearings_parquet,
|
| 19 |
NULL_TOKENS,
|
| 20 |
RUN_TS,
|
| 21 |
VERSION,
|
|
|
|
| 224 |
|
| 225 |
|
| 226 |
def save_clean(cases: pl.DataFrame, hearings: pl.DataFrame) -> None:
|
| 227 |
+
cases.write_parquet(str(_get_cases_parquet()))
|
| 228 |
+
hearings.write_parquet(str(_get_hearings_parquet()))
|
| 229 |
+
print(f"Saved cleaned cases -> {str(_get_cases_parquet())}")
|
| 230 |
+
print(f"Saved cleaned hearings -> {str(_get_hearings_parquet())}")
|
| 231 |
|
| 232 |
meta = {
|
| 233 |
"version": VERSION,
|
|
@@ -8,7 +8,7 @@ Responsibilities:
|
|
| 8 |
- Entropy of stage transitions (predictability).
|
| 9 |
- Case-type summary stats (disposal, hearing counts, gaps).
|
| 10 |
- Readiness score and alert flags per case.
|
| 11 |
-
- Export JSON/CSV parameter files into
|
| 12 |
"""
|
| 13 |
|
| 14 |
import json
|
|
@@ -16,15 +16,15 @@ from datetime import timedelta
|
|
| 16 |
|
| 17 |
import polars as pl
|
| 18 |
from src.eda_config import (
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
def load_cleaned():
|
| 26 |
-
cases = pl.read_parquet(
|
| 27 |
-
hearings = pl.read_parquet(
|
| 28 |
return cases, hearings
|
| 29 |
|
| 30 |
|
|
@@ -94,14 +94,14 @@ def extract_parameters() -> None:
|
|
| 94 |
<= pl.col("STAGE_TO").map_elements(lambda s: order_idx.get(s, 10))
|
| 95 |
).sort("N", descending=True)
|
| 96 |
|
| 97 |
-
transitions.write_csv(
|
| 98 |
|
| 99 |
# Probabilities per STAGE_FROM
|
| 100 |
row_tot = transitions.group_by("STAGE_FROM").agg(pl.col("N").sum().alias("row_n"))
|
| 101 |
trans_probs = transitions.join(row_tot, on="STAGE_FROM").with_columns(
|
| 102 |
(pl.col("N") / pl.col("row_n")).alias("p")
|
| 103 |
)
|
| 104 |
-
trans_probs.write_csv(
|
| 105 |
|
| 106 |
# Entropy of transitions
|
| 107 |
ent = (
|
|
@@ -109,7 +109,7 @@ def extract_parameters() -> None:
|
|
| 109 |
.agg((-(pl.col("p") * pl.col("p").log()).sum()).alias("entropy"))
|
| 110 |
.sort("entropy", descending=True)
|
| 111 |
)
|
| 112 |
-
ent.write_csv(
|
| 113 |
|
| 114 |
# Stage residence (runs)
|
| 115 |
runs = (
|
|
@@ -147,7 +147,7 @@ def extract_parameters() -> None:
|
|
| 147 |
)
|
| 148 |
.sort("RUN_MEDIAN_DAYS", descending=True)
|
| 149 |
)
|
| 150 |
-
stage_duration.write_csv(
|
| 151 |
|
| 152 |
# --------------------------------------------------
|
| 153 |
# 2. Court capacity (cases per courtroom per day)
|
|
@@ -169,13 +169,13 @@ def extract_parameters() -> None:
|
|
| 169 |
)
|
| 170 |
.sort("slots_median", descending=True)
|
| 171 |
)
|
| 172 |
-
cap_stats.write_csv(
|
| 173 |
# simple global aggregate
|
| 174 |
capacity_stats = {
|
| 175 |
"slots_median_global": float(cap["heard_count"].median()),
|
| 176 |
"slots_p90_global": float(cap["heard_count"].quantile(0.9)),
|
| 177 |
}
|
| 178 |
-
with open(
|
| 179 |
json.dump(capacity_stats, f, indent=2)
|
| 180 |
|
| 181 |
# --------------------------------------------------
|
|
@@ -245,7 +245,7 @@ def extract_parameters() -> None:
|
|
| 245 |
)
|
| 246 |
.sort(["Remappedstages", "casetype"])
|
| 247 |
)
|
| 248 |
-
outcome_stage.write_csv(
|
| 249 |
|
| 250 |
# --------------------------------------------------
|
| 251 |
# 4. Case-type summary and correlations
|
|
@@ -263,13 +263,13 @@ def extract_parameters() -> None:
|
|
| 263 |
)
|
| 264 |
.sort("n_cases", descending=True)
|
| 265 |
)
|
| 266 |
-
by_type.write_csv(
|
| 267 |
|
| 268 |
# Correlations for a quick diagnostic
|
| 269 |
corr_cols = ["DISPOSALTIME_ADJ", "N_HEARINGS", "GAP_MEDIAN"]
|
| 270 |
corr_df = cases.select(corr_cols).to_pandas()
|
| 271 |
corr = corr_df.corr(method="spearman")
|
| 272 |
-
corr.to_csv(
|
| 273 |
|
| 274 |
# --------------------------------------------------
|
| 275 |
# 5. Readiness score and alerts
|
|
@@ -364,7 +364,7 @@ def extract_parameters() -> None:
|
|
| 364 |
"ALERT_LONG_GAP",
|
| 365 |
]
|
| 366 |
feature_cols_existing = [c for c in feature_cols if c in cases.columns]
|
| 367 |
-
cases.select(feature_cols_existing).write_csv(
|
| 368 |
|
| 369 |
# Simple age funnel
|
| 370 |
if {"DATE_FILED", "DECISION_DATE"}.issubset(cases.columns):
|
|
@@ -388,12 +388,12 @@ def extract_parameters() -> None:
|
|
| 388 |
.agg(pl.len().alias("N"))
|
| 389 |
.sort("AGE_BUCKET")
|
| 390 |
)
|
| 391 |
-
age_funnel.write_csv(
|
| 392 |
|
| 393 |
|
| 394 |
def run_parameter_export() -> None:
|
| 395 |
extract_parameters()
|
| 396 |
-
print("Parameter extraction complete. Files in:",
|
| 397 |
|
| 398 |
|
| 399 |
if __name__ == "__main__":
|
|
|
|
| 8 |
- Entropy of stage transitions (predictability).
|
| 9 |
- Case-type summary stats (disposal, hearing counts, gaps).
|
| 10 |
- Readiness score and alert flags per case.
|
| 11 |
+
- Export JSON/CSV parameter files into _get_params_dir().
|
| 12 |
"""
|
| 13 |
|
| 14 |
import json
|
|
|
|
| 16 |
|
| 17 |
import polars as pl
|
| 18 |
from src.eda_config import (
|
| 19 |
+
_get_cases_parquet,
|
| 20 |
+
_get_hearings_parquet,
|
| 21 |
+
_get_params_dir,
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
def load_cleaned():
|
| 26 |
+
cases = pl.read_parquet(_get_cases_parquet())
|
| 27 |
+
hearings = pl.read_parquet(_get_hearings_parquet())
|
| 28 |
return cases, hearings
|
| 29 |
|
| 30 |
|
|
|
|
| 94 |
<= pl.col("STAGE_TO").map_elements(lambda s: order_idx.get(s, 10))
|
| 95 |
).sort("N", descending=True)
|
| 96 |
|
| 97 |
+
transitions.write_csv(str(_get_params_dir() / "stage_transitions.csv"))
|
| 98 |
|
| 99 |
# Probabilities per STAGE_FROM
|
| 100 |
row_tot = transitions.group_by("STAGE_FROM").agg(pl.col("N").sum().alias("row_n"))
|
| 101 |
trans_probs = transitions.join(row_tot, on="STAGE_FROM").with_columns(
|
| 102 |
(pl.col("N") / pl.col("row_n")).alias("p")
|
| 103 |
)
|
| 104 |
+
trans_probs.write_csv(str(_get_params_dir() / "stage_transition_probs.csv"))
|
| 105 |
|
| 106 |
# Entropy of transitions
|
| 107 |
ent = (
|
|
|
|
| 109 |
.agg((-(pl.col("p") * pl.col("p").log()).sum()).alias("entropy"))
|
| 110 |
.sort("entropy", descending=True)
|
| 111 |
)
|
| 112 |
+
ent.write_csv(str(_get_params_dir() / "stage_transition_entropy.csv"))
|
| 113 |
|
| 114 |
# Stage residence (runs)
|
| 115 |
runs = (
|
|
|
|
| 147 |
)
|
| 148 |
.sort("RUN_MEDIAN_DAYS", descending=True)
|
| 149 |
)
|
| 150 |
+
stage_duration.write_csv(str(_get_params_dir() / "stage_duration.csv"))
|
| 151 |
|
| 152 |
# --------------------------------------------------
|
| 153 |
# 2. Court capacity (cases per courtroom per day)
|
|
|
|
| 169 |
)
|
| 170 |
.sort("slots_median", descending=True)
|
| 171 |
)
|
| 172 |
+
cap_stats.write_csv(str(_get_params_dir() / "court_capacity_stats.csv"))
|
| 173 |
# simple global aggregate
|
| 174 |
capacity_stats = {
|
| 175 |
"slots_median_global": float(cap["heard_count"].median()),
|
| 176 |
"slots_p90_global": float(cap["heard_count"].quantile(0.9)),
|
| 177 |
}
|
| 178 |
+
with open(str(_get_params_dir() / "court_capacity_global.json"), "w") as f:
|
| 179 |
json.dump(capacity_stats, f, indent=2)
|
| 180 |
|
| 181 |
# --------------------------------------------------
|
|
|
|
| 245 |
)
|
| 246 |
.sort(["Remappedstages", "casetype"])
|
| 247 |
)
|
| 248 |
+
outcome_stage.write_csv(str(_get_params_dir() / "adjournment_proxies.csv"))
|
| 249 |
|
| 250 |
# --------------------------------------------------
|
| 251 |
# 4. Case-type summary and correlations
|
|
|
|
| 263 |
)
|
| 264 |
.sort("n_cases", descending=True)
|
| 265 |
)
|
| 266 |
+
by_type.write_csv(str(_get_params_dir() / "case_type_summary.csv"))
|
| 267 |
|
| 268 |
# Correlations for a quick diagnostic
|
| 269 |
corr_cols = ["DISPOSALTIME_ADJ", "N_HEARINGS", "GAP_MEDIAN"]
|
| 270 |
corr_df = cases.select(corr_cols).to_pandas()
|
| 271 |
corr = corr_df.corr(method="spearman")
|
| 272 |
+
corr.to_csv(str(_get_params_dir() / "correlations_spearman.csv"))
|
| 273 |
|
| 274 |
# --------------------------------------------------
|
| 275 |
# 5. Readiness score and alerts
|
|
|
|
| 364 |
"ALERT_LONG_GAP",
|
| 365 |
]
|
| 366 |
feature_cols_existing = [c for c in feature_cols if c in cases.columns]
|
| 367 |
+
cases.select(feature_cols_existing).write_csv(str(_get_params_dir() / "cases_features.csv"))
|
| 368 |
|
| 369 |
# Simple age funnel
|
| 370 |
if {"DATE_FILED", "DECISION_DATE"}.issubset(cases.columns):
|
|
|
|
| 388 |
.agg(pl.len().alias("N"))
|
| 389 |
.sort("AGE_BUCKET")
|
| 390 |
)
|
| 391 |
+
age_funnel.write_csv(str(_get_params_dir() / "age_funnel.csv"))
|
| 392 |
|
| 393 |
|
| 394 |
def run_parameter_export() -> None:
|
| 395 |
extract_parameters()
|
| 396 |
+
print("Parameter extraction complete. Files in:", _get_params_dir().resolve())
|
| 397 |
|
| 398 |
|
| 399 |
if __name__ == "__main__":
|