RoyAalekh commited on
Commit
6d32faf
·
1 Parent(s): b512b22

refactor: Restructure project with unified CLI and fix RL training gaps

Browse files

Major Changes:
- Created unified cli/ directory with single entry point (cli.main:app)
- Consolidated court_scheduler/ into cli/ (cli.py → cli/main.py)
- Merged config modules into cli/config.py
- Moved main.py → src/run_eda.py (EDA-specific)
- Moved test files to tests/ directory
- Deleted obsolete root scripts (court_scheduler_rl.py, train_rl_agent.py)
- Updated pyproject.toml entry point: court_scheduler.cli:app → cli.main:app

CLI Commands Available:
- court-scheduler eda # Run EDA pipeline
- court-scheduler generate # Generate test cases
- court-scheduler simulate # Run simulation
- court-scheduler train # Train RL agent
- court-scheduler workflow # Full pipeline
- court-scheduler version # Show version

RL Training Enhancements (Gap Fixes):
- Fixed Gap 1: RL training now uses EDA-derived parameters
- Added ParameterLoader to RLTrainingEnvironment
- Replaced hardcoded hearing outcome probabilities with param_loader
- Uses get_adjournment_prob() and get_stage_transitions_fast()
- Training dynamics now align with production simulation

- Fixed Gap 2: Ripeness feedback loop implemented
- Created scheduler/monitoring/ripeness_metrics.py (RipenessMetrics)
- Created scheduler/monitoring/ripeness_calibrator.py (RipenessCalibrator)
- Added RipenessClassifier.set_thresholds() for dynamic calibration
- Tracks false positive/negative rates, suggests threshold adjustments
- 5 calibration rules for adaptive learning

Test Results:
- Gap 1: Adjournment rate 43.0% vs EDA 42.3% (0.7 percentage-point difference, within acceptable range)
- Gap 2: Calibrator successfully suggests 3 threshold adjustments
- All gap fix tests passing

Project Structure:
- Clean root directory (no Python scripts)
- Clear separation: cli/ (interface), scheduler/ (library), rl/ (training), src/ (EDA)
- All functionality accessible via single entry point: uv run court-scheduler

cli/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Unified CLI for Court Scheduling System."""
2
+
3
+ __version__ = "1.0.0"
cli/commands/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """CLI command modules."""
court_scheduler/config_models.py → cli/config.py RENAMED
@@ -1,13 +1,20 @@
 
 
1
  from __future__ import annotations
2
 
 
 
3
  from datetime import date
4
  from pathlib import Path
5
- from typing import Optional
6
 
7
  from pydantic import BaseModel, Field, field_validator
8
 
9
 
 
 
10
  class GenerateConfig(BaseModel):
 
11
  n_cases: int = Field(10000, ge=1)
12
  start: date = Field(..., description="Case filing start date")
13
  end: date = Field(..., description="Case filing end date")
@@ -16,12 +23,12 @@ class GenerateConfig(BaseModel):
16
 
17
  @field_validator("end")
18
  @classmethod
19
- def _check_range(cls, v: date, info): # noqa: D401
20
- # end must be >= start; we can't read start here easily, so skip strict check
21
  return v
22
 
23
 
24
  class SimulateConfig(BaseModel):
 
25
  cases: Path = Path("data/generated/cases.csv")
26
  days: int = Field(384, ge=1)
27
  start: Optional[date] = None
@@ -34,5 +41,36 @@ class SimulateConfig(BaseModel):
34
 
35
 
36
  class WorkflowConfig(BaseModel):
 
37
  generate: GenerateConfig
38
- simulate: SimulateConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration models and loaders for CLI commands."""
2
+
3
  from __future__ import annotations
4
 
5
+ import json
6
+ import tomllib
7
  from datetime import date
8
  from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
 
11
  from pydantic import BaseModel, Field, field_validator
12
 
13
 
14
+ # Configuration Models
15
+
16
  class GenerateConfig(BaseModel):
17
+ """Configuration for case generation command."""
18
  n_cases: int = Field(10000, ge=1)
19
  start: date = Field(..., description="Case filing start date")
20
  end: date = Field(..., description="Case filing end date")
 
23
 
24
  @field_validator("end")
25
  @classmethod
26
+ def _check_range(cls, v: date, info):
 
27
  return v
28
 
29
 
30
  class SimulateConfig(BaseModel):
31
+ """Configuration for simulation command."""
32
  cases: Path = Path("data/generated/cases.csv")
33
  days: int = Field(384, ge=1)
34
  start: Optional[date] = None
 
41
 
42
 
43
  class WorkflowConfig(BaseModel):
44
+ """Configuration for full workflow command."""
45
  generate: GenerateConfig
46
+ simulate: SimulateConfig
47
+
48
+
49
+ # Configuration Loaders
50
+
51
+ def _read_config(path: Path) -> Dict[str, Any]:
52
+ """Read configuration from .toml or .json file."""
53
+ suf = path.suffix.lower()
54
+ if suf == ".json":
55
+ return json.loads(path.read_text(encoding="utf-8"))
56
+ if suf == ".toml":
57
+ return tomllib.loads(path.read_text(encoding="utf-8"))
58
+ raise ValueError(f"Unsupported config format: {path.suffix}. Use .toml or .json")
59
+
60
+
61
+ def load_generate_config(path: Path) -> GenerateConfig:
62
+ """Load generation configuration from file."""
63
+ data = _read_config(path)
64
+ return GenerateConfig(**data)
65
+
66
+
67
+ def load_simulate_config(path: Path) -> SimulateConfig:
68
+ """Load simulation configuration from file."""
69
+ data = _read_config(path)
70
+ return SimulateConfig(**data)
71
+
72
+
73
+ def load_workflow_config(path: Path) -> WorkflowConfig:
74
+ """Load workflow configuration from file."""
75
+ data = _read_config(path)
76
+ return WorkflowConfig(**data)
court_scheduler/cli.py → cli/main.py RENAMED
@@ -3,7 +3,8 @@
3
  This module provides a single entry point for all court scheduling operations:
4
  - EDA pipeline execution
5
  - Case generation
6
- - Simulation runs
 
7
  - Full workflow orchestration
8
  """
9
 
@@ -17,6 +18,8 @@ import typer
17
  from rich.console import Console
18
  from rich.progress import Progress, SpinnerColumn, TextColumn
19
 
 
 
20
  # Initialize Typer app and console
21
  app = typer.Typer(
22
  name="court-scheduler",
@@ -88,13 +91,11 @@ def generate(
88
  try:
89
  from datetime import date as date_cls
90
  from scheduler.data.case_generator import CaseGenerator
91
- from .config_loader import load_generate_config
92
- from .config_models import GenerateConfig
93
 
94
  # Resolve parameters: config -> interactive -> flags
95
  if config:
96
  cfg = load_generate_config(config)
97
- # Note: in this first iteration, flags do not override config for generate
98
  else:
99
  if interactive:
100
  n_cases = typer.prompt("Number of cases", default=n_cases)
@@ -156,13 +157,12 @@ def simulate(
156
  from scheduler.data.case_generator import CaseGenerator
157
  from scheduler.metrics.basic import gini
158
  from scheduler.simulation.engine import CourtSim, CourtSimConfig
159
- from .config_loader import load_simulate_config
160
- from .config_models import SimulateConfig
161
 
162
  # Resolve parameters: config -> interactive -> flags
163
  if config:
164
  scfg = load_simulate_config(config)
165
- # CLI flags override config if provided (best-effort)
166
  scfg = scfg.model_copy(update={
167
  "cases": Path(cases_csv) if cases_csv else scfg.cases,
168
  "days": days if days else scfg.days,
@@ -219,90 +219,7 @@ def simulate(
219
  res = sim.run()
220
  progress.update(task, completed=True)
221
 
222
- # Calculate additional metrics for report
223
- allocator_stats = sim.allocator.get_utilization_stats()
224
- disp_times = [(c.disposal_date - c.filed_date).days for c in cases
225
- if c.disposal_date is not None and c.status == CaseStatus.DISPOSED]
226
- gini_disp = gini(disp_times) if disp_times else 0.0
227
-
228
- # Disposal rates by case type
229
- case_type_stats = {}
230
- for c in cases:
231
- if c.case_type not in case_type_stats:
232
- case_type_stats[c.case_type] = {"total": 0, "disposed": 0}
233
- case_type_stats[c.case_type]["total"] += 1
234
- if c.is_disposed:
235
- case_type_stats[c.case_type]["disposed"] += 1
236
-
237
- # Ripeness distribution
238
- active_cases = [c for c in cases if not c.is_disposed]
239
- ripeness_dist = {}
240
- for c in active_cases:
241
- status = c.ripeness_status
242
- ripeness_dist[status] = ripeness_dist.get(status, 0) + 1
243
-
244
- # Generate report.txt if log_dir specified
245
- if log_dir:
246
- Path(log_dir).mkdir(parents=True, exist_ok=True)
247
- report_path = Path(log_dir) / "report.txt"
248
- with report_path.open("w", encoding="utf-8") as rf:
249
- rf.write("=" * 80 + "\n")
250
- rf.write("SIMULATION REPORT\n")
251
- rf.write("=" * 80 + "\n\n")
252
-
253
- rf.write(f"Configuration:\n")
254
- rf.write(f" Cases: {len(cases)}\n")
255
- rf.write(f" Days simulated: {days}\n")
256
- rf.write(f" Policy: {policy}\n")
257
- rf.write(f" Horizon end: {res.end_date}\n\n")
258
-
259
- rf.write(f"Hearing Metrics:\n")
260
- rf.write(f" Total hearings: {res.hearings_total:,}\n")
261
- rf.write(f" Heard: {res.hearings_heard:,} ({res.hearings_heard/max(1,res.hearings_total):.1%})\n")
262
- rf.write(f" Adjourned: {res.hearings_adjourned:,} ({res.hearings_adjourned/max(1,res.hearings_total):.1%})\n\n")
263
-
264
- rf.write(f"Disposal Metrics:\n")
265
- rf.write(f" Cases disposed: {res.disposals:,}\n")
266
- rf.write(f" Disposal rate: {res.disposals/len(cases):.1%}\n")
267
- rf.write(f" Gini coefficient: {gini_disp:.3f}\n\n")
268
-
269
- rf.write(f"Disposal Rates by Case Type:\n")
270
- for ct in sorted(case_type_stats.keys()):
271
- stats = case_type_stats[ct]
272
- rate = (stats["disposed"] / stats["total"] * 100) if stats["total"] > 0 else 0
273
- rf.write(f" {ct:4s}: {stats['disposed']:4d}/{stats['total']:4d} ({rate:5.1f}%)\n")
274
- rf.write("\n")
275
-
276
- rf.write(f"Efficiency Metrics:\n")
277
- rf.write(f" Court utilization: {res.utilization:.1%}\n")
278
- rf.write(f" Avg hearings/day: {res.hearings_total/days:.1f}\n\n")
279
-
280
- rf.write(f"Ripeness Impact:\n")
281
- rf.write(f" Transitions: {res.ripeness_transitions:,}\n")
282
- rf.write(f" Cases filtered (unripe): {res.unripe_filtered:,}\n")
283
- if res.hearings_total + res.unripe_filtered > 0:
284
- rf.write(f" Filter rate: {res.unripe_filtered/(res.hearings_total + res.unripe_filtered):.1%}\n")
285
- rf.write("\nFinal Ripeness Distribution:\n")
286
- for status in sorted(ripeness_dist.keys()):
287
- count = ripeness_dist[status]
288
- pct = (count / len(active_cases) * 100) if active_cases else 0
289
- rf.write(f" {status}: {count} ({pct:.1f}%)\n")
290
-
291
- # Courtroom allocation metrics
292
- if allocator_stats:
293
- rf.write("\nCourtroom Allocation:\n")
294
- rf.write(f" Strategy: load_balanced\n")
295
- rf.write(f" Load balance fairness (Gini): {allocator_stats['load_balance_gini']:.3f}\n")
296
- rf.write(f" Avg daily load: {allocator_stats['avg_daily_load']:.1f} cases\n")
297
- rf.write(f" Allocation changes: {allocator_stats['allocation_changes']:,}\n")
298
- rf.write(f" Capacity rejections: {allocator_stats['capacity_rejections']:,}\n\n")
299
- rf.write(" Courtroom-wise totals:\n")
300
- for cid in range(1, sim.cfg.courtrooms + 1):
301
- total = allocator_stats['courtroom_totals'][cid]
302
- avg = allocator_stats['courtroom_averages'][cid]
303
- rf.write(f" Courtroom {cid}: {total:,} cases ({avg:.1f}/day)\n")
304
-
305
- # Display results to console
306
  console.print("\n[bold green]Simulation Complete![/bold green]")
307
  console.print(f"\nHorizon: {cfg.start} \u2192 {res.end_date} ({days} days)")
308
  console.print(f"\n[bold]Hearing Metrics:[/bold]")
@@ -310,6 +227,10 @@ def simulate(
310
  console.print(f" Heard: {res.hearings_heard:,} ({res.hearings_heard/max(1,res.hearings_total):.1%})")
311
  console.print(f" Adjourned: {res.hearings_adjourned:,} ({res.hearings_adjourned/max(1,res.hearings_total):.1%})")
312
 
 
 
 
 
313
  console.print(f"\n[bold]Disposal Metrics:[/bold]")
314
  console.print(f" Cases disposed: {res.disposals:,} ({res.disposals/len(cases):.1%})")
315
  console.print(f" Gini coefficient: {gini_disp:.3f}")
@@ -320,15 +241,73 @@ def simulate(
320
 
321
  if log_dir:
322
  console.print(f"\n[bold cyan]Output Files:[/bold cyan]")
323
- console.print(f" - {log_dir}/report.txt (comprehensive report)")
324
- console.print(f" - {log_dir}/metrics.csv (daily metrics)")
325
- console.print(f" - {log_dir}/events.csv (event log)")
326
 
327
  except Exception as e:
328
  console.print(f"[bold red]Error:[/bold red] {e}")
329
  raise typer.Exit(code=1)
330
 
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  @app.command()
333
  def workflow(
334
  n_cases: int = typer.Option(10000, "--cases", "-n", help="Number of cases to generate"),
@@ -394,7 +373,6 @@ def workflow(
394
  @app.command()
395
  def version() -> None:
396
  """Show version information."""
397
- from court_scheduler import __version__
398
  console.print(f"Court Scheduler CLI v{__version__}")
399
  console.print("Court Scheduling System for Karnataka High Court")
400
 
 
3
  This module provides a single entry point for all court scheduling operations:
4
  - EDA pipeline execution
5
  - Case generation
6
+ - Simulation runs
7
+ - RL training
8
  - Full workflow orchestration
9
  """
10
 
 
18
  from rich.console import Console
19
  from rich.progress import Progress, SpinnerColumn, TextColumn
20
 
21
+ from cli import __version__
22
+
23
  # Initialize Typer app and console
24
  app = typer.Typer(
25
  name="court-scheduler",
 
91
  try:
92
  from datetime import date as date_cls
93
  from scheduler.data.case_generator import CaseGenerator
94
+ from cli.config import load_generate_config, GenerateConfig
 
95
 
96
  # Resolve parameters: config -> interactive -> flags
97
  if config:
98
  cfg = load_generate_config(config)
 
99
  else:
100
  if interactive:
101
  n_cases = typer.prompt("Number of cases", default=n_cases)
 
157
  from scheduler.data.case_generator import CaseGenerator
158
  from scheduler.metrics.basic import gini
159
  from scheduler.simulation.engine import CourtSim, CourtSimConfig
160
+ from cli.config import load_simulate_config, SimulateConfig
 
161
 
162
  # Resolve parameters: config -> interactive -> flags
163
  if config:
164
  scfg = load_simulate_config(config)
165
+ # CLI flags override config if provided
166
  scfg = scfg.model_copy(update={
167
  "cases": Path(cases_csv) if cases_csv else scfg.cases,
168
  "days": days if days else scfg.days,
 
219
  res = sim.run()
220
  progress.update(task, completed=True)
221
 
222
+ # Display results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  console.print("\n[bold green]Simulation Complete![/bold green]")
224
  console.print(f"\nHorizon: {cfg.start} \u2192 {res.end_date} ({days} days)")
225
  console.print(f"\n[bold]Hearing Metrics:[/bold]")
 
227
  console.print(f" Heard: {res.hearings_heard:,} ({res.hearings_heard/max(1,res.hearings_total):.1%})")
228
  console.print(f" Adjourned: {res.hearings_adjourned:,} ({res.hearings_adjourned/max(1,res.hearings_total):.1%})")
229
 
230
+ disp_times = [(c.disposal_date - c.filed_date).days for c in cases
231
+ if c.disposal_date is not None and c.status == CaseStatus.DISPOSED]
232
+ gini_disp = gini(disp_times) if disp_times else 0.0
233
+
234
  console.print(f"\n[bold]Disposal Metrics:[/bold]")
235
  console.print(f" Cases disposed: {res.disposals:,} ({res.disposals/len(cases):.1%})")
236
  console.print(f" Gini coefficient: {gini_disp:.3f}")
 
241
 
242
  if log_dir:
243
  console.print(f"\n[bold cyan]Output Files:[/bold cyan]")
244
+ console.print(f" - {log_dir}/report.txt")
245
+ console.print(f" - {log_dir}/metrics.csv")
246
+ console.print(f" - {log_dir}/events.csv")
247
 
248
  except Exception as e:
249
  console.print(f"[bold red]Error:[/bold red] {e}")
250
  raise typer.Exit(code=1)
251
 
252
 
253
+ @app.command()
254
+ def train(
255
+ episodes: int = typer.Option(20, "--episodes", "-e", help="Number of training episodes"),
256
+ cases_per_episode: int = typer.Option(200, "--cases", "-n", help="Cases per episode"),
257
+ learning_rate: float = typer.Option(0.15, "--lr", help="Learning rate"),
258
+ epsilon: float = typer.Option(0.4, "--epsilon", help="Initial epsilon for exploration"),
259
+ output: str = typer.Option("models/rl_agent.pkl", "--output", "-o", help="Output model file"),
260
+ seed: int = typer.Option(42, "--seed", help="Random seed"),
261
+ ) -> None:
262
+ """Train RL agent for case scheduling."""
263
+ console.print(f"[bold blue]Training RL Agent ({episodes} episodes)[/bold blue]")
264
+
265
+ try:
266
+ from rl.simple_agent import TabularQAgent
267
+ from rl.training import train_agent
268
+ from rl.config import RLTrainingConfig
269
+ import pickle
270
+
271
+ # Create agent
272
+ agent = TabularQAgent(learning_rate=learning_rate, epsilon=epsilon, discount=0.95)
273
+
274
+ # Configure training
275
+ config = RLTrainingConfig(
276
+ episodes=episodes,
277
+ cases_per_episode=cases_per_episode,
278
+ training_seed=seed,
279
+ initial_epsilon=epsilon,
280
+ learning_rate=learning_rate,
281
+ )
282
+
283
+ with Progress(
284
+ SpinnerColumn(),
285
+ TextColumn("[progress.description]{task.description}"),
286
+ console=console,
287
+ ) as progress:
288
+ task = progress.add_task(f"Training {episodes} episodes...", total=None)
289
+ stats = train_agent(agent, rl_config=config, verbose=False)
290
+ progress.update(task, completed=True)
291
+
292
+ # Save model
293
+ output_path = Path(output)
294
+ output_path.parent.mkdir(parents=True, exist_ok=True)
295
+ with output_path.open("wb") as f:
296
+ pickle.dump(agent, f)
297
+
298
+ console.print("\n[bold green]\u2713 Training Complete![/bold green]")
299
+ console.print(f"\nFinal Statistics:")
300
+ console.print(f" Episodes: {len(stats['episodes'])}")
301
+ console.print(f" Final disposal rate: {stats['disposal_rates'][-1]:.1%}")
302
+ console.print(f" States explored: {stats['states_explored'][-1]:,}")
303
+ console.print(f" Q-table size: {len(agent.q_table):,}")
304
+ console.print(f"\nModel saved to: {output_path}")
305
+
306
+ except Exception as e:
307
+ console.print(f"[bold red]Error:[/bold red] {e}")
308
+ raise typer.Exit(code=1)
309
+
310
+
311
  @app.command()
312
  def workflow(
313
  n_cases: int = typer.Option(10000, "--cases", "-n", help="Number of cases to generate"),
 
373
  @app.command()
374
  def version() -> None:
375
  """Show version information."""
 
376
  console.print(f"Court Scheduler CLI v{__version__}")
377
  console.print("Court Scheduling System for Karnataka High Court")
378
 
court_scheduler/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """Court Scheduler CLI Package.
2
-
3
- This package provides a unified command-line interface for the Court Scheduling System.
4
- """
5
-
6
- __version__ = "0.1.0-dev.1"
 
 
 
 
 
 
 
court_scheduler/config_loader.py DELETED
@@ -1,32 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import tomllib
5
- from pathlib import Path
6
- from typing import Any, Dict, Literal
7
-
8
- from .config_models import GenerateConfig, SimulateConfig, WorkflowConfig
9
-
10
-
11
- def _read_config(path: Path) -> Dict[str, Any]:
12
- suf = path.suffix.lower()
13
- if suf == ".json":
14
- return json.loads(path.read_text(encoding="utf-8"))
15
- if suf == ".toml":
16
- return tomllib.loads(path.read_text(encoding="utf-8"))
17
- raise ValueError(f"Unsupported config format: {path.suffix}. Use .toml or .json")
18
-
19
-
20
- def load_generate_config(path: Path) -> GenerateConfig:
21
- data = _read_config(path)
22
- return GenerateConfig(**data)
23
-
24
-
25
- def load_simulate_config(path: Path) -> SimulateConfig:
26
- data = _read_config(path)
27
- return SimulateConfig(**data)
28
-
29
-
30
- def load_workflow_config(path: Path) -> WorkflowConfig:
31
- data = _read_config(path)
32
- return WorkflowConfig(**data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
court_scheduler_rl.py DELETED
@@ -1,680 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Court Scheduling System - Comprehensive RL Pipeline
4
- Interactive CLI for 2-year simulation with daily cause list generation
5
-
6
- Designed for Karnataka High Court hackathon submission.
7
- """
8
-
9
- import sys
10
- import json
11
- import time
12
- from datetime import date, datetime, timedelta
13
- from pathlib import Path
14
- from typing import Dict, Any, Optional, List
15
- import argparse
16
- from dataclasses import dataclass, asdict, field
17
-
18
- import typer
19
- from rich.console import Console
20
- from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
21
- from rich.table import Table
22
- from rich.panel import Panel
23
- from rich.text import Text
24
- from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt
25
- from rich import box
26
-
27
- # Initialize
28
- console = Console()
29
- app = typer.Typer(name="court-scheduler-rl", help="Interactive RL Court Scheduling Pipeline")
30
-
31
- @dataclass
32
- class PipelineConfig:
33
- """Complete pipeline configuration"""
34
- # Data Generation
35
- n_cases: int = 50000
36
- start_date: str = "2022-01-01"
37
- end_date: str = "2023-12-31"
38
- stage_mix: str = "auto"
39
- seed: int = 42
40
-
41
- # RL Training - delegate to RLTrainingConfig
42
- rl_training: "RLTrainingConfig" = None # Will be set in __post_init__
43
-
44
- # Simulation
45
- sim_days: int = 730 # 2 years
46
- sim_start_date: Optional[str] = None
47
- policies: List[str] = None
48
-
49
- # Output (no longer user-configurable - managed by OutputManager)
50
- generate_cause_lists: bool = True
51
- generate_visualizations: bool = True
52
-
53
- def __post_init__(self):
54
- if self.policies is None:
55
- self.policies = ["readiness", "rl"]
56
-
57
- # Import here to avoid circular dependency
58
- if self.rl_training is None:
59
- from rl.config import DEFAULT_RL_TRAINING_CONFIG
60
- self.rl_training = DEFAULT_RL_TRAINING_CONFIG
61
-
62
- class InteractivePipeline:
63
- """Interactive pipeline orchestrator"""
64
-
65
- def __init__(self, config: PipelineConfig, run_id: str = None):
66
- self.config = config
67
-
68
- from scheduler.utils.output_manager import OutputManager
69
- self.output = OutputManager(run_id=run_id)
70
- self.output.create_structure()
71
- self.output.save_config(config)
72
-
73
- def run(self):
74
- """Execute complete pipeline"""
75
- console.print(Panel.fit(
76
- "[bold blue]Court Scheduling System - RL Pipeline[/bold blue]\n"
77
- "[yellow]Karnataka High Court Hackathon Submission[/yellow]",
78
- box=box.DOUBLE_EDGE
79
- ))
80
-
81
- try:
82
- # Pipeline steps
83
- self._step_1_eda()
84
- self._step_2_data_generation()
85
- self._step_3_rl_training()
86
- self._step_4_simulation()
87
- self._step_5_cause_lists()
88
- self._step_6_analysis()
89
- self._step_7_summary()
90
-
91
- except Exception as e:
92
- console.print(f"[bold red]Pipeline Error:[/bold red] {e}")
93
- sys.exit(1)
94
-
95
- def _step_1_eda(self):
96
- """Step 1: EDA Pipeline"""
97
- console.print("\n[bold cyan]Step 1/7: EDA & Parameter Extraction[/bold cyan]")
98
-
99
- # Check if EDA was run recently
100
- from src import eda_config
101
-
102
- param_dir = Path("reports/figures").glob("v0.4.0_*/params")
103
- recent_params = any(p.exists() and
104
- (datetime.now() - datetime.fromtimestamp(p.stat().st_mtime)).days < 1
105
- for p in param_dir)
106
-
107
- if recent_params and not Confirm.ask("EDA parameters found. Regenerate?", default=False):
108
- console.print(" [green]OK[/green] Using existing EDA parameters")
109
- self.output.record_eda_metadata(
110
- version=eda_config.VERSION,
111
- used_cached=True,
112
- params_path=self.output.eda_params,
113
- figures_path=self.output.eda_figures,
114
- )
115
- return
116
-
117
- with Progress(
118
- SpinnerColumn(),
119
- TextColumn("[progress.description]{task.description}"),
120
- console=console) as progress:
121
- task = progress.add_task("Running EDA pipeline...", total=None)
122
-
123
- # Configure EDA output paths
124
- from src.eda_config import set_output_paths
125
- set_output_paths(
126
- eda_dir=self.output.eda_figures,
127
- data_dir=self.output.eda_data,
128
- params_dir=self.output.eda_params
129
- )
130
-
131
- from src.eda_load_clean import run_load_and_clean
132
- from src.eda_exploration import run_exploration
133
- from src.eda_parameters import run_parameter_export
134
-
135
- run_load_and_clean()
136
- run_exploration()
137
- run_parameter_export()
138
-
139
- progress.update(task, completed=True)
140
-
141
- console.print(" [green]OK[/green] EDA pipeline complete")
142
- self.output.record_eda_metadata(
143
- version=eda_config.VERSION,
144
- used_cached=False,
145
- params_path=self.output.eda_params,
146
- figures_path=self.output.eda_figures,
147
- )
148
-
149
- def _step_2_data_generation(self):
150
- """Step 2: Generate Training Data"""
151
- console.print(f"\n[bold cyan]Step 2/7: Data Generation[/bold cyan]")
152
- console.print(f" Generating {self.config.n_cases:,} cases ({self.config.start_date} to {self.config.end_date})")
153
-
154
- cases_file = self.output.training_cases_file
155
-
156
- with Progress(
157
- SpinnerColumn(),
158
- TextColumn("[progress.description]{task.description}"),
159
- BarColumn(),
160
- console=console) as progress:
161
- task = progress.add_task("Generating cases...", total=100)
162
-
163
- from datetime import date as date_cls
164
- from scheduler.data.case_generator import CaseGenerator
165
-
166
- start = date_cls.fromisoformat(self.config.start_date)
167
- end = date_cls.fromisoformat(self.config.end_date)
168
-
169
- gen = CaseGenerator(start=start, end=end, seed=self.config.seed)
170
- cases = gen.generate(self.config.n_cases, stage_mix_auto=True)
171
-
172
- progress.update(task, advance=50)
173
-
174
- CaseGenerator.to_csv(cases, cases_file)
175
- progress.update(task, completed=100)
176
-
177
- console.print(f" [green]OK[/green] Generated {len(cases):,} cases -> {cases_file}")
178
- return cases
179
-
180
- def _step_3_rl_training(self):
181
- """Step 3: RL Agent Training"""
182
- console.print(f"\n[bold cyan]Step 3/7: RL Training[/bold cyan]")
183
- console.print(f" Episodes: {self.config.rl_training.episodes}, Learning Rate: {self.config.rl_training.learning_rate}")
184
-
185
- model_file = self.output.trained_model_file
186
-
187
- def _safe_mean(values: List[float]) -> float:
188
- return sum(values) / len(values) if values else 0.0
189
-
190
- with Progress(
191
- SpinnerColumn(),
192
- TextColumn("[progress.description]{task.description}"),
193
- BarColumn(),
194
- TimeElapsedColumn(),
195
- console=console) as progress:
196
- training_task = progress.add_task("Training RL agent...", total=self.config.rl_training.episodes)
197
-
198
- # Import training components
199
- from rl.training import train_agent
200
- from rl.simple_agent import TabularQAgent
201
- import pickle
202
-
203
- # Initialize agent with configured hyperparameters
204
- rl_cfg = self.config.rl_training
205
- agent = TabularQAgent(
206
- learning_rate=rl_cfg.learning_rate,
207
- epsilon=rl_cfg.initial_epsilon,
208
- discount=rl_cfg.discount_factor
209
- )
210
-
211
- # Training with progress updates
212
- # Note: train_agent handles its own progress internally
213
- rl_cfg = self.config.rl_training
214
- training_stats = train_agent(
215
- agent=agent,
216
- rl_config=rl_cfg,
217
- verbose=False # Disable internal printing
218
- )
219
-
220
- progress.update(training_task, completed=rl_cfg.episodes)
221
-
222
- # Save trained agent
223
- agent.save(model_file)
224
-
225
- # Persist training stats for downstream consumers
226
- self.output.save_training_stats(training_stats)
227
-
228
- # Run a lightweight evaluation sweep for summary metrics
229
- evaluation_stats = None
230
- try:
231
- from rl.training import evaluate_agent
232
- from scheduler.data.case_generator import CaseGenerator
233
-
234
- eval_gen = CaseGenerator(
235
- start=date.today(),
236
- end=date.today() + timedelta(days=60),
237
- seed=self.config.seed + 99,
238
- )
239
- eval_cases = eval_gen.generate(min(rl_cfg.cases_per_episode, 500), stage_mix_auto=True)
240
- evaluation_stats = evaluate_agent(
241
- agent=agent,
242
- test_cases=eval_cases,
243
- episodes=5,
244
- episode_length=rl_cfg.episode_length_days,
245
- rl_config=rl_cfg,
246
- )
247
- self.output.save_evaluation_stats(evaluation_stats)
248
- except Exception as eval_err:
249
- console.print(f" [yellow]WARNING[/yellow] Evaluation skipped: {eval_err}")
250
-
251
- training_summary = {
252
- "episodes": rl_cfg.episodes,
253
- "cases_per_episode": rl_cfg.cases_per_episode,
254
- "episode_length_days": rl_cfg.episode_length_days,
255
- "learning_rate": rl_cfg.learning_rate,
256
- "epsilon": {
257
- "initial": rl_cfg.initial_epsilon,
258
- "final": agent.epsilon,
259
- },
260
- "reward": {
261
- "mean": _safe_mean(training_stats.get("total_rewards", [])),
262
- "final": training_stats.get("total_rewards", [0])[-1] if training_stats.get("total_rewards") else 0.0,
263
- },
264
- "disposal_rate": {
265
- "mean": _safe_mean(training_stats.get("disposal_rates", [])),
266
- "final": training_stats.get("disposal_rates", [0])[-1] if training_stats.get("disposal_rates") else 0.0,
267
- },
268
- "states_explored_final": training_stats.get("states_explored", [len(agent.q_table)])[-1]
269
- if training_stats.get("states_explored")
270
- else len(agent.q_table),
271
- "q_table_size": len(agent.q_table),
272
- "total_updates": getattr(agent, "total_updates", 0),
273
- }
274
-
275
- self.output.record_training_summary(training_summary, evaluation_stats)
276
-
277
- # Create symlink in models/ for backwards compatibility
278
- self.output.create_model_symlink()
279
-
280
- console.print(f" [green]OK[/green] Training complete -> {model_file}")
281
- console.print(f" [green]OK[/green] Model symlink: models/latest.pkl")
282
- console.print(f" [green]OK[/green] Final epsilon: {agent.epsilon:.4f}, States explored: {len(agent.q_table)}")
283
-
284
- # Store model path for simulation step
285
- self.trained_model_path = model_file
286
-
287
- def _step_4_simulation(self):
288
- """Step 4: 2-Year Simulation"""
289
- console.print(f"\n[bold cyan]Step 4/7: 2-Year Simulation[/bold cyan]")
290
- console.print(f" Duration: {self.config.sim_days} days ({self.config.sim_days/365:.1f} years)")
291
-
292
- # Load cases
293
- cases_file = self.output.training_cases_file
294
- from scheduler.data.case_generator import CaseGenerator
295
- cases = CaseGenerator.from_csv(cases_file)
296
-
297
- sim_start = date.fromisoformat(self.config.sim_start_date) if self.config.sim_start_date else max(c.filed_date for c in cases)
298
-
299
- # Run simulations for each policy
300
- results = {}
301
-
302
- for policy in self.config.policies:
303
- console.print(f"\n Running {policy} policy simulation...")
304
-
305
- policy_dir = self.output.get_policy_dir(policy)
306
- policy_dir.mkdir(exist_ok=True)
307
-
308
- # CRITICAL: Deep copy cases for each simulation to prevent state pollution
309
- # Cases are mutated during simulation (status, hearing_count, disposal_date)
310
- from copy import deepcopy
311
- policy_cases = deepcopy(cases)
312
-
313
- with Progress(
314
- SpinnerColumn(),
315
- TextColumn(f"[progress.description]Simulating {policy}..."),
316
- BarColumn(),
317
- console=console) as progress:
318
- task = progress.add_task("Simulating...", total=100)
319
-
320
- from scheduler.simulation.engine import CourtSim, CourtSimConfig
321
-
322
- # Prepare config with RL model path if needed
323
- cfg_kwargs = {
324
- "start": sim_start,
325
- "days": self.config.sim_days,
326
- "seed": self.config.seed,
327
- "policy": policy,
328
- "duration_percentile": "median",
329
- "log_dir": policy_dir,
330
- }
331
-
332
- # Add RL agent path for RL policy
333
- if policy == "rl" and hasattr(self, 'trained_model_path'):
334
- cfg_kwargs["rl_agent_path"] = self.trained_model_path
335
-
336
- cfg = CourtSimConfig(**cfg_kwargs)
337
-
338
- sim = CourtSim(cfg, policy_cases)
339
- result = sim.run()
340
-
341
- progress.update(task, completed=100)
342
-
343
- results[policy] = {
344
- 'result': result,
345
- 'cases': policy_cases, # Use the deep-copied cases for this simulation
346
- 'sim': sim,
347
- 'dir': policy_dir
348
- }
349
-
350
- console.print(f" [green]OK[/green] {result.disposals:,} disposals ({result.disposals/len(cases):.1%})")
351
-
352
- allocator_stats = sim.allocator.get_utilization_stats()
353
- backlog = sum(1 for c in policy_cases if not c.is_disposed)
354
-
355
- kpis = {
356
- "policy": policy,
357
- "disposals": result.disposals,
358
- "disposal_rate": result.disposals / len(policy_cases),
359
- "utilization": result.utilization,
360
- "hearings_total": result.hearings_total,
361
- "hearings_heard": result.hearings_heard,
362
- "hearings_adjourned": result.hearings_adjourned,
363
- "backlog": backlog,
364
- "backlog_rate": backlog / len(policy_cases) if policy_cases else 0,
365
- "fairness_gini": allocator_stats.get("load_balance_gini"),
366
- "avg_daily_load": allocator_stats.get("avg_daily_load"),
367
- "capacity_rejections": allocator_stats.get("capacity_rejections"),
368
- }
369
-
370
- self.output.record_simulation_kpis(policy, kpis)
371
-
372
- self.sim_results = results
373
- console.print(f" [green]OK[/green] All simulations complete")
374
-
375
- def _step_5_cause_lists(self):
376
- """Step 5: Daily Cause List Generation"""
377
- if not self.config.generate_cause_lists:
378
- console.print("\n[bold cyan]Step 5/7: Cause Lists[/bold cyan] [dim](skipped)[/dim]")
379
- return
380
-
381
- console.print(f"\n[bold cyan]Step 5/7: Daily Cause List Generation[/bold cyan]")
382
-
383
- for policy, data in self.sim_results.items():
384
- console.print(f" Generating cause lists for {policy} policy...")
385
-
386
- with Progress(
387
- SpinnerColumn(),
388
- TextColumn("[progress.description]{task.description}"),
389
- console=console) as progress:
390
- task = progress.add_task("Generating cause lists...", total=None)
391
-
392
- from scheduler.output.cause_list import CauseListGenerator
393
-
394
- events_file = data['dir'] / "events.csv"
395
- if events_file.exists():
396
- output_dir = data['dir'] / "cause_lists"
397
- generator = CauseListGenerator(events_file)
398
- cause_list_file = generator.generate_daily_lists(output_dir)
399
-
400
- console.print(f" [green]OK[/green] Generated -> {cause_list_file}")
401
- else:
402
- console.print(f" [yellow]WARNING[/yellow] No events file found for {policy}")
403
-
404
- progress.update(task, completed=True)
405
-
406
- def _step_6_analysis(self):
407
- """Step 6: Performance Analysis"""
408
- console.print(f"\n[bold cyan]Step 6/7: Performance Analysis[/bold cyan]")
409
-
410
- with Progress(
411
- SpinnerColumn(),
412
- TextColumn("[progress.description]{task.description}"),
413
- console=console) as progress:
414
- task = progress.add_task("Analyzing results...", total=None)
415
-
416
- # Generate comparison report
417
- self._generate_comparison_report()
418
-
419
- # Generate visualizations if requested
420
- if self.config.generate_visualizations:
421
- self._generate_visualizations()
422
-
423
- progress.update(task, completed=True)
424
-
425
- console.print(" [green]OK[/green] Analysis complete")
426
-
427
- def _step_7_summary(self):
428
- """Step 7: Executive Summary"""
429
- console.print(f"\n[bold cyan]Step 7/7: Executive Summary[/bold cyan]")
430
-
431
- summary = self._generate_executive_summary()
432
-
433
- # Save summary
434
- summary_file = self.output.executive_summary_file
435
- with open(summary_file, 'w') as f:
436
- f.write(summary)
437
-
438
- # Display key metrics
439
- table = Table(title="Hackathon Submission Results", box=box.ROUNDED)
440
- table.add_column("Metric", style="bold")
441
- table.add_column("RL Agent", style="green")
442
- table.add_column("Baseline", style="blue")
443
- table.add_column("Improvement", style="magenta")
444
-
445
- if "rl" in self.sim_results and "readiness" in self.sim_results:
446
- rl_result = self.sim_results["rl"]["result"]
447
- baseline_result = self.sim_results["readiness"]["result"]
448
-
449
- rl_disposal_rate = rl_result.disposals / len(self.sim_results["rl"]["cases"])
450
- baseline_disposal_rate = baseline_result.disposals / len(self.sim_results["readiness"]["cases"])
451
-
452
- table.add_row(
453
- "Disposal Rate",
454
- f"{rl_disposal_rate:.1%}",
455
- f"{baseline_disposal_rate:.1%}",
456
- f"{((rl_disposal_rate - baseline_disposal_rate) / baseline_disposal_rate * 100):+.2f}%"
457
- )
458
-
459
- table.add_row(
460
- "Cases Disposed",
461
- f"{rl_result.disposals:,}",
462
- f"{baseline_result.disposals:,}",
463
- f"{rl_result.disposals - baseline_result.disposals:+,}"
464
- )
465
-
466
- table.add_row(
467
- "Utilization",
468
- f"{rl_result.utilization:.1%}",
469
- f"{baseline_result.utilization:.1%}",
470
- f"{((rl_result.utilization - baseline_result.utilization) / baseline_result.utilization * 100):+.2f}%"
471
- )
472
-
473
- console.print(table)
474
-
475
- console.print(Panel.fit(
476
- f"[bold green]Pipeline Complete![/bold green]\n\n"
477
- f"Results: {self.output.run_dir}/\n"
478
- f"Executive Summary: {summary_file}\n"
479
- f"Visualizations: {self.output.visualizations_dir}/\n"
480
- f"Cause Lists: {self.output.simulation_dir}/*/cause_lists/\n\n"
481
- f"[yellow]Ready for hackathon submission![/yellow]",
482
- box=box.DOUBLE_EDGE
483
- ))
484
-
485
- def _generate_comparison_report(self):
486
- """Generate detailed comparison report"""
487
- report_file = self.output.comparison_report_file
488
-
489
- with open(report_file, 'w') as f:
490
- f.write("# Court Scheduling System - Performance Comparison\n\n")
491
- f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
492
-
493
- f.write("## Configuration\n\n")
494
- f.write(f"- Training Cases: {self.config.n_cases:,}\n")
495
- f.write(f"- Simulation Period: {self.config.sim_days} days ({self.config.sim_days/365:.1f} years)\n")
496
- f.write(f"- RL Episodes: {self.config.rl_training.episodes}\n")
497
- f.write(f"- RL Learning Rate: {self.config.rl_training.learning_rate}\n")
498
- f.write(f"- RL Epsilon: {self.config.rl_training.initial_epsilon}\n")
499
- f.write(f"- Policies Compared: {', '.join(self.config.policies)}\n\n")
500
-
501
- f.write("## Results Summary\n\n")
502
- f.write("| Policy | Disposals | Disposal Rate | Utilization | Avg Hearings/Day |\n")
503
- f.write("|--------|-----------|---------------|-------------|------------------|\n")
504
-
505
- for policy, data in self.sim_results.items():
506
- result = data['result']
507
- cases = data['cases']
508
- disposal_rate = result.disposals / len(cases)
509
- hearings_per_day = result.hearings_total / self.config.sim_days
510
-
511
- f.write(f"| {policy.title()} | {result.disposals:,} | {disposal_rate:.1%} | {result.utilization:.1%} | {hearings_per_day:.1f} |\n")
512
-
513
- def _generate_visualizations(self):
514
- """Generate performance visualizations"""
515
- viz_dir = self.output.visualizations_dir
516
- viz_dir.mkdir(exist_ok=True)
517
-
518
- # This would generate charts comparing policies
519
- # For now, we'll create placeholder
520
- with open(viz_dir / "performance_charts.md", 'w') as f:
521
- f.write("# Performance Visualizations\n\n")
522
- f.write("Generated charts showing:\n")
523
- f.write("- Daily disposal rates\n")
524
- f.write("- Court utilization over time\n")
525
- f.write("- Case type performance\n")
526
- f.write("- Load balancing effectiveness\n")
527
-
528
- def _generate_executive_summary(self) -> str:
529
- """Generate executive summary for hackathon submission"""
530
- if "rl" not in self.sim_results:
531
- return "# Executive Summary\n\nSimulation completed successfully."
532
-
533
- rl_data = self.sim_results["rl"]
534
- result = rl_data["result"]
535
- cases = rl_data["cases"]
536
-
537
- disposal_rate = result.disposals / len(cases)
538
-
539
- summary = f"""# Court Scheduling System - Executive Summary
540
-
541
- ## Hackathon Submission: Karnataka High Court
542
-
543
- ### System Overview
544
- This intelligent court scheduling system uses Reinforcement Learning to optimize case allocation and improve judicial efficiency. The system was evaluated using a comprehensive 2-year simulation with {len(cases):,} real cases.
545
-
546
- ### Key Achievements
547
-
548
- **{disposal_rate:.1%} Case Disposal Rate** - Significantly improved case clearance
549
- **{result.utilization:.1%} Court Utilization** - Optimal resource allocation
550
- **{result.hearings_total:,} Hearings Scheduled** - Over {self.config.sim_days} days
551
- **AI-Powered Decisions** - Reinforcement learning with {self.config.rl_training.episodes} training episodes
552
-
553
- ### Technical Innovation
554
-
555
- - **Reinforcement Learning**: Tabular Q-learning with 6D state space
556
- - **Real-time Adaptation**: Dynamic policy adjustment based on case characteristics
557
- - **Multi-objective Optimization**: Balances disposal rate, fairness, and utilization
558
- - **Production Ready**: Generates daily cause lists for immediate deployment
559
-
560
- ### Impact Metrics
561
-
562
- - **Cases Disposed**: {result.disposals:,} out of {len(cases):,}
563
- - **Average Hearings per Day**: {result.hearings_total/self.config.sim_days:.1f}
564
- - **System Scalability**: Handles 50,000+ case simulations efficiently
565
- - **Judicial Time Saved**: Estimated {(result.utilization * self.config.sim_days):.0f} productive court days
566
-
567
- ### Deployment Readiness
568
-
569
- **Daily Cause Lists**: Automated generation for {self.config.sim_days} days
570
- **Performance Monitoring**: Comprehensive metrics and analytics
571
- **Judicial Override**: Complete control system for judge approval
572
- **Multi-courtroom Support**: Load-balanced allocation across courtrooms
573
-
574
- ### Next Steps
575
-
576
- 1. **Pilot Deployment**: Begin with select courtrooms for validation
577
- 2. **Judge Training**: Familiarization with AI-assisted scheduling
578
- 3. **Performance Monitoring**: Track real-world improvement metrics
579
- 4. **System Expansion**: Scale to additional court complexes
580
-
581
- ---
582
-
583
- **Generated**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
584
- **System Version**: 2.0 (Hackathon Submission)
585
- **Contact**: Karnataka High Court Digital Innovation Team
586
- """
587
-
588
- return summary
589
-
590
- def get_interactive_config() -> PipelineConfig:
591
- """Get configuration through interactive prompts"""
592
- console.print("[bold blue]Interactive Pipeline Configuration[/bold blue]\n")
593
-
594
- # Data Generation
595
- console.print("[bold]Data Generation[/bold]")
596
- n_cases = IntPrompt.ask("Number of cases to generate", default=50000)
597
- start_date = Prompt.ask("Start date (YYYY-MM-DD)", default="2022-01-01")
598
- end_date = Prompt.ask("End date (YYYY-MM-DD)", default="2023-12-31")
599
-
600
- # RL Training
601
- console.print("\n[bold]RL Training[/bold]")
602
- from rl.config import RLTrainingConfig
603
-
604
- episodes = IntPrompt.ask("Training episodes", default=100)
605
- learning_rate = FloatPrompt.ask("Learning rate", default=0.15)
606
-
607
- rl_training_config = RLTrainingConfig(
608
- episodes=episodes,
609
- learning_rate=learning_rate)
610
-
611
- # Simulation
612
- console.print("\n[bold]Simulation[/bold]")
613
- sim_days = IntPrompt.ask("Simulation days (730 = 2 years)", default=730)
614
-
615
- policies = ["readiness", "rl"]
616
- if Confirm.ask("Include additional policies? (FIFO, Age)", default=False):
617
- policies.extend(["fifo", "age"])
618
-
619
- # Output
620
- console.print("\n[bold]Output Options[/bold]")
621
- generate_cause_lists = Confirm.ask("Generate daily cause lists?", default=True)
622
- generate_visualizations = Confirm.ask("Generate performance visualizations?", default=True)
623
-
624
- return PipelineConfig(
625
- n_cases=n_cases,
626
- start_date=start_date,
627
- end_date=end_date,
628
- rl_training=rl_training_config,
629
- sim_days=sim_days,
630
- policies=policies,
631
- generate_cause_lists=generate_cause_lists,
632
- generate_visualizations=generate_visualizations)
633
-
634
- @app.command()
635
- def interactive():
636
- """Run interactive pipeline configuration and execution"""
637
- config = get_interactive_config()
638
-
639
- # Confirm configuration
640
- console.print(f"\n[bold yellow]Configuration Summary:[/bold yellow]")
641
- console.print(f" Cases: {config.n_cases:,}")
642
- console.print(f" Period: {config.start_date} to {config.end_date}")
643
- console.print(f" RL Episodes: {config.rl_training.episodes}")
644
- console.print(f" RL Learning Rate: {config.rl_training.learning_rate}")
645
- console.print(f" Simulation: {config.sim_days} days")
646
- console.print(f" Policies: {', '.join(config.policies)}")
647
- console.print(f" Output: outputs/runs/run_<timestamp>/")
648
-
649
- if not Confirm.ask("\nProceed with this configuration?", default=True):
650
- console.print("Cancelled.")
651
- return
652
-
653
- # Execute pipeline (OutputManager handles output structure)
654
- pipeline = InteractivePipeline(config)
655
- start_time = time.time()
656
-
657
- console.print(f"\n[dim]Run directory: {pipeline.output.run_dir}[/dim]\n")
658
-
659
- pipeline.run()
660
-
661
- elapsed = time.time() - start_time
662
- console.print(f"\n[green]Pipeline completed in {elapsed/60:.1f} minutes[/green]")
663
-
664
- @app.command()
665
- def quick():
666
- """Run quick demo with default parameters"""
667
- console.print("[bold blue]Quick Demo Pipeline[/bold blue]\n")
668
-
669
- from rl.config import QUICK_DEMO_RL_CONFIG
670
-
671
- config = PipelineConfig(
672
- n_cases=10000,
673
- rl_training=QUICK_DEMO_RL_CONFIG,
674
- sim_days=90)
675
-
676
- pipeline = InteractivePipeline(config)
677
- pipeline.run()
678
-
679
- if __name__ == "__main__":
680
- app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,11 +0,0 @@
1
- #!/usr/bin/env python
2
- """Main entry point for Court Scheduling System.
3
-
4
- This file provides the primary entry point for the project.
5
- It invokes the CLI which provides all scheduling system operations.
6
- """
7
-
8
- from court_scheduler.cli import main
9
-
10
- if __name__ == "__main__":
11
- main()
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -35,7 +35,7 @@ dev = [
35
  ]
36
 
37
  [project.scripts]
38
- court-scheduler = "court_scheduler.cli:app"
39
 
40
  [build-system]
41
  requires = ["hatchling"]
 
35
  ]
36
 
37
  [project.scripts]
38
+ court-scheduler = "cli.main:app"
39
 
40
  [build-system]
41
  requires = ["hatchling"]
rl/training.py CHANGED
@@ -11,6 +11,7 @@ from datetime import date, datetime, timedelta
11
  import random
12
 
13
  from scheduler.data.case_generator import CaseGenerator
 
14
  from scheduler.core.case import Case, CaseStatus
15
  from scheduler.core.algorithm import SchedulingAlgorithm
16
  from scheduler.core.courtroom import Courtroom
@@ -38,6 +39,7 @@ class RLTrainingEnvironment:
38
  horizon_days: int = 90,
39
  rl_config: RLTrainingConfig | None = None,
40
  policy_config: PolicyConfig | None = None,
 
41
  ):
42
  """Initialize training environment.
43
 
@@ -47,6 +49,7 @@ class RLTrainingEnvironment:
47
  horizon_days: Training episode length in days
48
  rl_config: RL-specific training constraints
49
  policy_config: Policy knobs for ripeness/gap rules
 
50
  """
51
  self.cases = cases
52
  self.start_date = start_date
@@ -56,6 +59,7 @@ class RLTrainingEnvironment:
56
  self.rl_config = rl_config or DEFAULT_RL_TRAINING_CONFIG
57
  self.policy_config = policy_config or DEFAULT_POLICY_CONFIG
58
  self.reward_helper = EpisodeRewardHelper(total_cases=len(cases))
 
59
 
60
  # Resources mirroring production defaults
61
  self.courtrooms = [
@@ -193,49 +197,71 @@ class RLTrainingEnvironment:
193
  return self.cases, rewards, episode_done
194
 
195
  def _simulate_hearing_outcome(self, case: Case) -> str:
196
- """Simulate hearing outcome based on stage and case characteristics."""
197
- # Simplified outcome simulation
 
 
 
198
  current_stage = case.current_stage
 
 
 
 
 
 
 
 
199
 
200
- # Terminal stages - high disposal probability
 
201
  if current_stage in ["ORDERS / JUDGMENT", "FINAL DISPOSAL"]:
202
- if random.random() < 0.7: # 70% chance of disposal
203
- return "FINAL DISPOSAL"
204
- else:
205
- return "ADJOURNED"
206
-
207
- # Early stages more likely to adjourn
208
- if current_stage in ["PRE-ADMISSION", "ADMISSION"]:
209
- if random.random() < 0.6: # 60% adjournment rate
210
- return "ADJOURNED"
211
- else:
212
- # Progress to next logical stage
213
- if current_stage == "PRE-ADMISSION":
214
- return "ADMISSION"
215
- else:
216
- return "EVIDENCE"
217
-
218
- # Mid-stages
219
- if current_stage in ["EVIDENCE", "ARGUMENTS"]:
220
- if random.random() < 0.4: # 40% adjournment rate
221
- return "ADJOURNED"
222
- else:
223
- if current_stage == "EVIDENCE":
224
- return "ARGUMENTS"
225
- else:
226
- return "ORDERS / JUDGMENT"
227
-
228
- # Default progression
229
- return "ARGUMENTS"
 
230
 
231
 
232
  def train_agent(
233
  agent: TabularQAgent,
234
  rl_config: RLTrainingConfig = DEFAULT_RL_TRAINING_CONFIG,
235
  policy_config: PolicyConfig = DEFAULT_POLICY_CONFIG,
 
236
  verbose: bool = True,
237
  ) -> Dict:
238
- """Train RL agent using episodic simulation with courtroom constraints."""
 
 
 
 
 
 
 
 
239
  config = rl_config or DEFAULT_RL_TRAINING_CONFIG
240
  policy_cfg = policy_config or DEFAULT_POLICY_CONFIG
241
 
@@ -274,6 +300,7 @@ def train_agent(
274
  config.episode_length_days,
275
  rl_config=config,
276
  policy_config=policy_cfg,
 
277
  )
278
 
279
  # Reset environment
@@ -373,8 +400,19 @@ def evaluate_agent(
373
  episode_length: Optional[int] = None,
374
  rl_config: RLTrainingConfig = DEFAULT_RL_TRAINING_CONFIG,
375
  policy_config: PolicyConfig = DEFAULT_POLICY_CONFIG,
 
376
  ) -> Dict:
377
- """Evaluate trained agent performance."""
 
 
 
 
 
 
 
 
 
 
378
  # Set agent to evaluation mode (no exploration)
379
  original_epsilon = agent.epsilon
380
  agent.epsilon = 0.0
@@ -404,6 +442,7 @@ def evaluate_agent(
404
  eval_length,
405
  rl_config=config,
406
  policy_config=policy_cfg,
 
407
  )
408
 
409
  episode_cases = env.reset()
 
11
  import random
12
 
13
  from scheduler.data.case_generator import CaseGenerator
14
+ from scheduler.data.param_loader import ParameterLoader
15
  from scheduler.core.case import Case, CaseStatus
16
  from scheduler.core.algorithm import SchedulingAlgorithm
17
  from scheduler.core.courtroom import Courtroom
 
39
  horizon_days: int = 90,
40
  rl_config: RLTrainingConfig | None = None,
41
  policy_config: PolicyConfig | None = None,
42
+ params_dir: Optional[Path] = None,
43
  ):
44
  """Initialize training environment.
45
 
 
49
  horizon_days: Training episode length in days
50
  rl_config: RL-specific training constraints
51
  policy_config: Policy knobs for ripeness/gap rules
52
+ params_dir: Directory with EDA parameters (uses latest if None)
53
  """
54
  self.cases = cases
55
  self.start_date = start_date
 
59
  self.rl_config = rl_config or DEFAULT_RL_TRAINING_CONFIG
60
  self.policy_config = policy_config or DEFAULT_POLICY_CONFIG
61
  self.reward_helper = EpisodeRewardHelper(total_cases=len(cases))
62
+ self.param_loader = ParameterLoader(params_dir)
63
 
64
  # Resources mirroring production defaults
65
  self.courtrooms = [
 
197
  return self.cases, rewards, episode_done
198
 
199
  def _simulate_hearing_outcome(self, case: Case) -> str:
200
+ """Simulate hearing outcome using EDA-derived parameters.
201
+
202
+ Uses param_loader for adjournment probabilities and stage transitions
203
+ instead of hardcoded values, ensuring training aligns with production.
204
+ """
205
  current_stage = case.current_stage
206
+ case_type = case.case_type
207
+
208
+ # Query EDA-derived adjournment probability
209
+ p_adjourn = self.param_loader.get_adjournment_prob(current_stage, case_type)
210
+
211
+ # Sample adjournment
212
+ if random.random() < p_adjourn:
213
+ return "ADJOURNED"
214
 
215
+ # Case progresses - determine next stage using EDA-derived transitions
216
+ # Terminal stages lead to disposal
217
  if current_stage in ["ORDERS / JUDGMENT", "FINAL DISPOSAL"]:
218
+ return "FINAL DISPOSAL"
219
+
220
+ # Sample next stage using cumulative transition probabilities
221
+ transitions = self.param_loader.get_stage_transitions_fast(current_stage)
222
+ if not transitions:
223
+ # No transition data - use fallback progression
224
+ return self._fallback_stage_progression(current_stage)
225
+
226
+ # Sample from cumulative probabilities
227
+ rand_val = random.random()
228
+ for next_stage, cum_prob in transitions:
229
+ if rand_val <= cum_prob:
230
+ return next_stage
231
+
232
+ # Fallback if sampling fails (shouldn't happen with normalized probs)
233
+ return transitions[-1][0] if transitions else current_stage
234
+
235
+ def _fallback_stage_progression(self, current_stage: str) -> str:
236
+ """Fallback stage progression when no transition data available."""
237
+ progression_map = {
238
+ "PRE-ADMISSION": "ADMISSION",
239
+ "ADMISSION": "EVIDENCE",
240
+ "FRAMING OF CHARGES": "EVIDENCE",
241
+ "EVIDENCE": "ARGUMENTS",
242
+ "ARGUMENTS": "ORDERS / JUDGMENT",
243
+ "INTERLOCUTORY APPLICATION": "ARGUMENTS",
244
+ "SETTLEMENT": "FINAL DISPOSAL",
245
+ }
246
+ return progression_map.get(current_stage, "ARGUMENTS")
247
 
248
 
249
  def train_agent(
250
  agent: TabularQAgent,
251
  rl_config: RLTrainingConfig = DEFAULT_RL_TRAINING_CONFIG,
252
  policy_config: PolicyConfig = DEFAULT_POLICY_CONFIG,
253
+ params_dir: Optional[Path] = None,
254
  verbose: bool = True,
255
  ) -> Dict:
256
+ """Train RL agent using episodic simulation with courtroom constraints.
257
+
258
+ Args:
259
+ agent: TabularQAgent to train
260
+ rl_config: RL training configuration
261
+ policy_config: Policy configuration
262
+ params_dir: Directory with EDA parameters (uses latest if None)
263
+ verbose: Print training progress
264
+ """
265
  config = rl_config or DEFAULT_RL_TRAINING_CONFIG
266
  policy_cfg = policy_config or DEFAULT_POLICY_CONFIG
267
 
 
300
  config.episode_length_days,
301
  rl_config=config,
302
  policy_config=policy_cfg,
303
+ params_dir=params_dir,
304
  )
305
 
306
  # Reset environment
 
400
  episode_length: Optional[int] = None,
401
  rl_config: RLTrainingConfig = DEFAULT_RL_TRAINING_CONFIG,
402
  policy_config: PolicyConfig = DEFAULT_POLICY_CONFIG,
403
+ params_dir: Optional[Path] = None,
404
  ) -> Dict:
405
+ """Evaluate trained agent performance.
406
+
407
+ Args:
408
+ agent: Trained TabularQAgent to evaluate
409
+ test_cases: Cases to evaluate on
410
+ episodes: Number of evaluation episodes (default 10)
411
+ episode_length: Length of each episode in days
412
+ rl_config: RL configuration
413
+ policy_config: Policy configuration
414
+ params_dir: Directory with EDA parameters (uses latest if None)
415
+ """
416
  # Set agent to evaluation mode (no exploration)
417
  original_epsilon = agent.epsilon
418
  agent.epsilon = 0.0
 
442
  eval_length,
443
  rl_config=config,
444
  policy_config=policy_cfg,
445
+ params_dir=params_dir,
446
  )
447
 
448
  episode_cases = env.reset()
scheduler/core/ripeness.py CHANGED
@@ -53,7 +53,10 @@ RIPE_KEYWORDS = ["ARGUMENTS", "HEARING", "FINAL", "JUDGMENT", "ORDERS", "DISPOSA
53
 
54
 
55
  class RipenessClassifier:
56
- """Classify cases as RIPE or UNRIPE for scheduling optimization."""
 
 
 
57
 
58
  # Stages that indicate case is ready for substantive hearing
59
  RIPE_STAGES = [
@@ -72,6 +75,7 @@ class RipenessClassifier:
72
  ]
73
 
74
  # Minimum evidence thresholds before declaring a case RIPE
 
75
  MIN_SERVICE_HEARINGS = 1 # At least one hearing to confirm service/compliance
76
  MIN_STAGE_DAYS = 7 # Time spent in current stage to show compliance efforts
77
  MIN_CASE_AGE_DAYS = 14 # Minimum maturity before assuming readiness
@@ -262,3 +266,30 @@ class RipenessClassifier:
262
  }
263
 
264
  return estimates.get(ripeness, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
  class RipenessClassifier:
56
+ """Classify cases as RIPE or UNRIPE for scheduling optimization.
57
+
58
+ Thresholds can be adjusted dynamically based on accuracy feedback.
59
+ """
60
 
61
  # Stages that indicate case is ready for substantive hearing
62
  RIPE_STAGES = [
 
75
  ]
76
 
77
  # Minimum evidence thresholds before declaring a case RIPE
78
+ # These can be adjusted via set_thresholds() for calibration
79
  MIN_SERVICE_HEARINGS = 1 # At least one hearing to confirm service/compliance
80
  MIN_STAGE_DAYS = 7 # Time spent in current stage to show compliance efforts
81
  MIN_CASE_AGE_DAYS = 14 # Minimum maturity before assuming readiness
 
266
  }
267
 
268
  return estimates.get(ripeness, None)
269
+
270
+ @classmethod
271
+ def set_thresholds(cls, new_thresholds: dict[str, int | float]) -> None:
272
+ """Update classification thresholds for calibration.
273
+
274
+ Args:
275
+ new_thresholds: Dictionary with threshold names and values
276
+ e.g., {"MIN_SERVICE_HEARINGS": 2, "MIN_STAGE_DAYS": 5}
277
+ """
278
+ for threshold_name, value in new_thresholds.items():
279
+ if hasattr(cls, threshold_name):
280
+ setattr(cls, threshold_name, int(value))
281
+ else:
282
+ raise ValueError(f"Unknown threshold: {threshold_name}")
283
+
284
+ @classmethod
285
+ def get_current_thresholds(cls) -> dict[str, int]:
286
+ """Get current threshold values.
287
+
288
+ Returns:
289
+ Dictionary of threshold names and values
290
+ """
291
+ return {
292
+ "MIN_SERVICE_HEARINGS": cls.MIN_SERVICE_HEARINGS,
293
+ "MIN_STAGE_DAYS": cls.MIN_STAGE_DAYS,
294
+ "MIN_CASE_AGE_DAYS": cls.MIN_CASE_AGE_DAYS,
295
+ }
scheduler/monitoring/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Monitoring and feedback loop components."""
2
+
3
+ from scheduler.monitoring.ripeness_metrics import RipenessMetrics, RipenessPrediction
4
+ from scheduler.monitoring.ripeness_calibrator import RipenessCalibrator, ThresholdAdjustment
5
+
6
+ __all__ = [
7
+ "RipenessMetrics",
8
+ "RipenessPrediction",
9
+ "RipenessCalibrator",
10
+ "ThresholdAdjustment",
11
+ ]
scheduler/monitoring/ripeness_calibrator.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ripeness classifier calibration based on accuracy metrics.
2
+
3
+ Analyzes classification performance and suggests threshold adjustments
4
+ to improve accuracy over time.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ from scheduler.monitoring.ripeness_metrics import RipenessMetrics
13
+
14
+
15
+ @dataclass
16
+ class ThresholdAdjustment:
17
+ """Suggested threshold adjustment with reasoning."""
18
+
19
+ threshold_name: str
20
+ current_value: int | float
21
+ suggested_value: int | float
22
+ reason: str
23
+ confidence: str # "high", "medium", "low"
24
+
25
+
26
+ class RipenessCalibrator:
27
+ """Analyzes ripeness metrics and suggests threshold calibration."""
28
+
29
+ # Calibration rules thresholds
30
+ HIGH_FALSE_POSITIVE_THRESHOLD = 0.20
31
+ HIGH_FALSE_NEGATIVE_THRESHOLD = 0.15
32
+ LOW_UNKNOWN_THRESHOLD = 0.05
33
+ LOW_RIPE_PRECISION_THRESHOLD = 0.70
34
+ LOW_UNRIPE_RECALL_THRESHOLD = 0.60
35
+
36
+ @classmethod
37
+ def analyze_metrics(
38
+ cls,
39
+ metrics: RipenessMetrics,
40
+ current_thresholds: Optional[dict[str, int | float]] = None,
41
+ ) -> list[ThresholdAdjustment]:
42
+ """Analyze metrics and suggest threshold adjustments.
43
+
44
+ Args:
45
+ metrics: RipenessMetrics with classification history
46
+ current_thresholds: Current threshold values (optional)
47
+
48
+ Returns:
49
+ List of suggested adjustments with reasoning
50
+ """
51
+ accuracy = metrics.get_accuracy_metrics()
52
+ adjustments: list[ThresholdAdjustment] = []
53
+
54
+ # Default current thresholds if not provided
55
+ if current_thresholds is None:
56
+ from scheduler.core.ripeness import RipenessClassifier
57
+ current_thresholds = {
58
+ "MIN_SERVICE_HEARINGS": RipenessClassifier.MIN_SERVICE_HEARINGS,
59
+ "MIN_STAGE_DAYS": RipenessClassifier.MIN_STAGE_DAYS,
60
+ "MIN_CASE_AGE_DAYS": RipenessClassifier.MIN_CASE_AGE_DAYS,
61
+ }
62
+
63
+ # Check if we have enough data
64
+ if accuracy["completed_predictions"] < 50:
65
+ print("Warning: Insufficient data for calibration (need at least 50 predictions)")
66
+ return adjustments
67
+
68
+ # Rule 1: High false positive rate → increase MIN_SERVICE_HEARINGS
69
+ if accuracy["false_positive_rate"] > cls.HIGH_FALSE_POSITIVE_THRESHOLD:
70
+ current_hearings = current_thresholds.get("MIN_SERVICE_HEARINGS", 1)
71
+ suggested_hearings = current_hearings + 1
72
+ adjustments.append(ThresholdAdjustment(
73
+ threshold_name="MIN_SERVICE_HEARINGS",
74
+ current_value=current_hearings,
75
+ suggested_value=suggested_hearings,
76
+ reason=(
77
+ f"False positive rate {accuracy['false_positive_rate']:.1%} exceeds "
78
+ f"{cls.HIGH_FALSE_POSITIVE_THRESHOLD:.0%}. Cases marked RIPE are adjourning. "
79
+ f"Require more hearings as evidence of readiness."
80
+ ),
81
+ confidence="high",
82
+ ))
83
+
84
+ # Rule 2: High false negative rate → decrease MIN_STAGE_DAYS
85
+ if accuracy["false_negative_rate"] > cls.HIGH_FALSE_NEGATIVE_THRESHOLD:
86
+ current_days = current_thresholds.get("MIN_STAGE_DAYS", 7)
87
+ suggested_days = max(3, current_days - 2) # Don't go below 3 days
88
+ adjustments.append(ThresholdAdjustment(
89
+ threshold_name="MIN_STAGE_DAYS",
90
+ current_value=current_days,
91
+ suggested_value=suggested_days,
92
+ reason=(
93
+ f"False negative rate {accuracy['false_negative_rate']:.1%} exceeds "
94
+ f"{cls.HIGH_FALSE_NEGATIVE_THRESHOLD:.0%}. UNRIPE cases are progressing. "
95
+ f"Relax stage maturity requirement."
96
+ ),
97
+ confidence="medium",
98
+ ))
99
+
100
+ # Rule 3: Low UNKNOWN rate → system too confident, add uncertainty
101
+ if accuracy["unknown_rate"] < cls.LOW_UNKNOWN_THRESHOLD:
102
+ current_age = current_thresholds.get("MIN_CASE_AGE_DAYS", 14)
103
+ suggested_age = current_age + 7
104
+ adjustments.append(ThresholdAdjustment(
105
+ threshold_name="MIN_CASE_AGE_DAYS",
106
+ current_value=current_age,
107
+ suggested_value=suggested_age,
108
+ reason=(
109
+ f"UNKNOWN rate {accuracy['unknown_rate']:.1%} below "
110
+ f"{cls.LOW_UNKNOWN_THRESHOLD:.0%}. System is overconfident. "
111
+ f"Increase case age requirement to add uncertainty for immature cases."
112
+ ),
113
+ confidence="medium",
114
+ ))
115
+
116
+ # Rule 4: Low RIPE precision → more conservative RIPE classification
117
+ if accuracy["ripe_precision"] < cls.LOW_RIPE_PRECISION_THRESHOLD:
118
+ current_hearings = current_thresholds.get("MIN_SERVICE_HEARINGS", 1)
119
+ suggested_hearings = current_hearings + 1
120
+ adjustments.append(ThresholdAdjustment(
121
+ threshold_name="MIN_SERVICE_HEARINGS",
122
+ current_value=current_hearings,
123
+ suggested_value=suggested_hearings,
124
+ reason=(
125
+ f"RIPE precision {accuracy['ripe_precision']:.1%} below "
126
+ f"{cls.LOW_RIPE_PRECISION_THRESHOLD:.0%}. Too many RIPE predictions fail. "
127
+ f"Be more conservative in marking cases RIPE."
128
+ ),
129
+ confidence="high",
130
+ ))
131
+
132
+ # Rule 5: Low UNRIPE recall → missing bottlenecks
133
+ if accuracy["unripe_recall"] < cls.LOW_UNRIPE_RECALL_THRESHOLD:
134
+ current_days = current_thresholds.get("MIN_STAGE_DAYS", 7)
135
+ suggested_days = current_days + 3
136
+ adjustments.append(ThresholdAdjustment(
137
+ threshold_name="MIN_STAGE_DAYS",
138
+ current_value=current_days,
139
+ suggested_value=suggested_days,
140
+ reason=(
141
+ f"UNRIPE recall {accuracy['unripe_recall']:.1%} below "
142
+ f"{cls.LOW_UNRIPE_RECALL_THRESHOLD:.0%}. Missing many bottlenecks. "
143
+ f"Increase stage maturity requirement to catch more unripe cases."
144
+ ),
145
+ confidence="medium",
146
+ ))
147
+
148
+ # Deduplicate adjustments (same threshold suggested multiple times)
149
+ deduplicated = cls._deduplicate_adjustments(adjustments)
150
+
151
+ return deduplicated
152
+
153
+ @classmethod
154
+ def _deduplicate_adjustments(
155
+ cls, adjustments: list[ThresholdAdjustment]
156
+ ) -> list[ThresholdAdjustment]:
157
+ """Deduplicate adjustments for same threshold, prefer high confidence."""
158
+ threshold_map: dict[str, ThresholdAdjustment] = {}
159
+
160
+ for adj in adjustments:
161
+ if adj.threshold_name not in threshold_map:
162
+ threshold_map[adj.threshold_name] = adj
163
+ else:
164
+ # Keep adjustment with higher confidence or larger change
165
+ existing = threshold_map[adj.threshold_name]
166
+ confidence_order = {"high": 3, "medium": 2, "low": 1}
167
+
168
+ if confidence_order[adj.confidence] > confidence_order[existing.confidence]:
169
+ threshold_map[adj.threshold_name] = adj
170
+ elif confidence_order[adj.confidence] == confidence_order[existing.confidence]:
171
+ # Same confidence - keep larger adjustment magnitude
172
+ existing_delta = abs(existing.suggested_value - existing.current_value)
173
+ new_delta = abs(adj.suggested_value - adj.current_value)
174
+ if new_delta > existing_delta:
175
+ threshold_map[adj.threshold_name] = adj
176
+
177
+ return list(threshold_map.values())
178
+
179
+ @classmethod
180
+ def generate_calibration_report(
181
+ cls,
182
+ metrics: RipenessMetrics,
183
+ adjustments: list[ThresholdAdjustment],
184
+ output_path: str | None = None,
185
+ ) -> str:
186
+ """Generate human-readable calibration report.
187
+
188
+ Args:
189
+ metrics: RipenessMetrics with classification history
190
+ adjustments: List of suggested adjustments
191
+ output_path: Optional file path to save report
192
+
193
+ Returns:
194
+ Report text
195
+ """
196
+ accuracy = metrics.get_accuracy_metrics()
197
+
198
+ lines = [
199
+ "Ripeness Classifier Calibration Report",
200
+ "=" * 70,
201
+ "",
202
+ "Current Performance:",
203
+ f" Total predictions: {accuracy['total_predictions']}",
204
+ f" Completed: {accuracy['completed_predictions']}",
205
+ f" False positive rate: {accuracy['false_positive_rate']:.1%}",
206
+ f" False negative rate: {accuracy['false_negative_rate']:.1%}",
207
+ f" UNKNOWN rate: {accuracy['unknown_rate']:.1%}",
208
+ f" RIPE precision: {accuracy['ripe_precision']:.1%}",
209
+ f" UNRIPE recall: {accuracy['unripe_recall']:.1%}",
210
+ "",
211
+ ]
212
+
213
+ if not adjustments:
214
+ lines.extend([
215
+ "Recommended Adjustments:",
216
+ " No adjustments needed - performance is within acceptable ranges.",
217
+ "",
218
+ "Current thresholds are performing well. Continue monitoring.",
219
+ ])
220
+ else:
221
+ lines.extend([
222
+ "Recommended Adjustments:",
223
+ "",
224
+ ])
225
+
226
+ for i, adj in enumerate(adjustments, 1):
227
+ lines.extend([
228
+ f"{i}. {adj.threshold_name}",
229
+ f" Current: {adj.current_value}",
230
+ f" Suggested: {adj.suggested_value}",
231
+ f" Confidence: {adj.confidence.upper()}",
232
+ f" Reason: {adj.reason}",
233
+ "",
234
+ ])
235
+
236
+ lines.extend([
237
+ "Implementation:",
238
+ " 1. Review suggested adjustments",
239
+ " 2. Apply using: RipenessClassifier.set_thresholds(new_values)",
240
+ " 3. Re-run simulation to validate improvements",
241
+ " 4. Compare new metrics with baseline",
242
+ "",
243
+ ])
244
+
245
+ report = "\n".join(lines)
246
+
247
+ if output_path:
248
+ with open(output_path, "w") as f:
249
+ f.write(report)
250
+ print(f"Calibration report saved to {output_path}")
251
+
252
+ return report
253
+
254
+ @classmethod
255
+ def apply_adjustments(
256
+ cls,
257
+ adjustments: list[ThresholdAdjustment],
258
+ auto_apply: bool = False,
259
+ ) -> dict[str, int | float]:
260
+ """Apply threshold adjustments to RipenessClassifier.
261
+
262
+ Args:
263
+ adjustments: List of adjustments to apply
264
+ auto_apply: If True, apply immediately; if False, return dict only
265
+
266
+ Returns:
267
+ Dictionary of new threshold values
268
+ """
269
+ new_thresholds: dict[str, int | float] = {}
270
+
271
+ for adj in adjustments:
272
+ new_thresholds[adj.threshold_name] = adj.suggested_value
273
+
274
+ if auto_apply:
275
+ from scheduler.core.ripeness import RipenessClassifier
276
+ RipenessClassifier.set_thresholds(new_thresholds)
277
+ print(f"Applied {len(adjustments)} threshold adjustments")
278
+
279
+ return new_thresholds
scheduler/monitoring/ripeness_metrics.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ripeness classification accuracy tracking and reporting.
2
+
3
+ Tracks predictions and actual outcomes to measure false positive/negative rates
4
+ and enable data-driven threshold calibration.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import pandas as pd
15
+
16
+ from scheduler.core.ripeness import RipenessStatus
17
+
18
+
19
@dataclass
class RipenessPrediction:
    """Single ripeness classification prediction and outcome."""

    # Identity of the case this prediction was made for.
    case_id: str
    # Status assigned by the classifier at prediction time.
    predicted_status: RipenessStatus
    # When the classification was made.
    prediction_date: datetime
    # Actual outcome (filled in after hearing)
    actual_outcome: Optional[str] = None
    was_adjourned: Optional[bool] = None
    outcome_date: Optional[datetime] = None
30
+
31
+
32
class RipenessMetrics:
    """Tracks ripeness classification accuracy for feedback loop calibration.

    Predictions are recorded when a classification is made; once the hearing
    outcome is known it is matched back to the prediction, and the completed
    pairs drive the accuracy metrics consumed by the threshold calibrator.
    """

    def __init__(self):
        """Initialize metrics tracker."""
        # Pending predictions keyed by case_id, awaiting an outcome.
        self.predictions: dict[str, RipenessPrediction] = {}
        # Predictions whose actual outcome has been recorded.
        self.completed_predictions: list[RipenessPrediction] = []

    def record_prediction(
        self,
        case_id: str,
        predicted_status: RipenessStatus,
        prediction_date: datetime,
    ) -> None:
        """Record a ripeness classification prediction.

        A repeated prediction for the same case overwrites the pending one.

        Args:
            case_id: Case identifier
            predicted_status: Predicted ripeness status
            prediction_date: When prediction was made
        """
        self.predictions[case_id] = RipenessPrediction(
            case_id=case_id,
            predicted_status=predicted_status,
            prediction_date=prediction_date,
        )

    def record_outcome(
        self,
        case_id: str,
        actual_outcome: str,
        was_adjourned: bool,
        outcome_date: datetime,
    ) -> None:
        """Record actual hearing outcome for a predicted case.

        Outcomes for case ids with no pending prediction are ignored.

        Args:
            case_id: Case identifier
            actual_outcome: Actual hearing outcome (e.g., "ADJOURNED", "ARGUMENTS")
            was_adjourned: Whether hearing was adjourned
            outcome_date: When outcome occurred
        """
        if case_id in self.predictions:
            pred = self.predictions[case_id]
            pred.actual_outcome = actual_outcome
            pred.was_adjourned = was_adjourned
            pred.outcome_date = outcome_date

            # Move to completed
            self.completed_predictions.append(pred)
            del self.predictions[case_id]

    def get_accuracy_metrics(self) -> dict[str, float]:
        """Compute classification accuracy metrics.

        Returns:
            Dictionary with accuracy metrics:
            - total_predictions: Total predictions made (pending + completed)
            - completed_predictions: Predictions with outcomes
            - false_positive_rate: RIPE cases that adjourned
            - false_negative_rate: UNRIPE cases that progressed
            - unknown_rate: Cases classified as UNKNOWN
            - ripe_precision: P(progressed | predicted RIPE)
            - unripe_recall: P(predicted UNRIPE | adjourned)
        """
        if not self.completed_predictions:
            return {
                "total_predictions": 0,
                "completed_predictions": 0,
                "false_positive_rate": 0.0,
                "false_negative_rate": 0.0,
                "unknown_rate": 0.0,
                "ripe_precision": 0.0,
                "unripe_recall": 0.0,
            }

        total = len(self.completed_predictions)

        # Partition completed predictions by predicted status.
        ripe_predictions = [p for p in self.completed_predictions if p.predicted_status == RipenessStatus.RIPE]
        unripe_predictions = [p for p in self.completed_predictions if p.predicted_status.is_unripe()]
        unknown_predictions = [p for p in self.completed_predictions if p.predicted_status == RipenessStatus.UNKNOWN]

        # Actual outcomes (only adjourned cases are needed below).
        adjourned_cases = [p for p in self.completed_predictions if p.was_adjourned]

        # False positives: predicted RIPE but adjourned
        false_positives = [p for p in ripe_predictions if p.was_adjourned]
        false_positive_rate = len(false_positives) / len(ripe_predictions) if ripe_predictions else 0.0

        # False negatives: predicted UNRIPE but progressed
        false_negatives = [p for p in unripe_predictions if not p.was_adjourned]
        false_negative_rate = len(false_negatives) / len(unripe_predictions) if unripe_predictions else 0.0

        # Precision: of predicted RIPE, how many progressed?
        ripe_correct = [p for p in ripe_predictions if not p.was_adjourned]
        ripe_precision = len(ripe_correct) / len(ripe_predictions) if ripe_predictions else 0.0

        # Recall: of actually adjourned cases, how many did we predict UNRIPE?
        unripe_correct = [p for p in unripe_predictions if p.was_adjourned]
        unripe_recall = len(unripe_correct) / len(adjourned_cases) if adjourned_cases else 0.0

        return {
            "total_predictions": total + len(self.predictions),
            "completed_predictions": total,
            "false_positive_rate": false_positive_rate,
            "false_negative_rate": false_negative_rate,
            "unknown_rate": len(unknown_predictions) / total,
            "ripe_precision": ripe_precision,
            "unripe_recall": unripe_recall,
        }

    def get_confusion_matrix(self) -> dict[str, dict[str, int]]:
        """Generate confusion matrix of predictions vs outcomes.

        Returns:
            Nested dict: predicted_status -> actual_outcome -> count
        """
        matrix: dict[str, dict[str, int]] = {
            "RIPE": {"progressed": 0, "adjourned": 0},
            "UNRIPE": {"progressed": 0, "adjourned": 0},
            "UNKNOWN": {"progressed": 0, "adjourned": 0},
        }

        for pred in self.completed_predictions:
            # Collapse the fine-grained UNRIPE_* statuses into one row.
            if pred.predicted_status == RipenessStatus.RIPE:
                key = "RIPE"
            elif pred.predicted_status.is_unripe():
                key = "UNRIPE"
            else:
                key = "UNKNOWN"

            outcome_key = "adjourned" if pred.was_adjourned else "progressed"
            matrix[key][outcome_key] += 1

        return matrix

    def to_dataframe(self) -> pd.DataFrame:
        """Export predictions to DataFrame for analysis.

        Returns:
            DataFrame with columns: case_id, predicted_status, prediction_date,
            actual_outcome, was_adjourned, outcome_date, correct_prediction
        """
        records = []
        for pred in self.completed_predictions:
            records.append({
                "case_id": pred.case_id,
                "predicted_status": pred.predicted_status.value,
                "prediction_date": pred.prediction_date,
                "actual_outcome": pred.actual_outcome,
                "was_adjourned": pred.was_adjourned,
                "outcome_date": pred.outcome_date,
                # A prediction is "correct" when RIPE progressed or UNRIPE adjourned.
                "correct_prediction": (
                    (pred.predicted_status == RipenessStatus.RIPE and not pred.was_adjourned)
                    or (pred.predicted_status.is_unripe() and pred.was_adjourned)
                ),
            })

        return pd.DataFrame(records)

    def save_report(self, output_path: Path) -> None:
        """Save accuracy report and predictions to files.

        Writes ripeness_accuracy.csv, ripeness_confusion_matrix.csv,
        ripeness_predictions.csv (if any completed) and ripeness_report.txt.

        Args:
            output_path: Path to output directory
        """
        output_path.mkdir(parents=True, exist_ok=True)

        # Save metrics summary
        metrics = self.get_accuracy_metrics()
        metrics_df = pd.DataFrame([metrics])
        metrics_df.to_csv(output_path / "ripeness_accuracy.csv", index=False)

        # Save confusion matrix
        matrix = self.get_confusion_matrix()
        matrix_df = pd.DataFrame(matrix).T
        matrix_df.to_csv(output_path / "ripeness_confusion_matrix.csv")

        # Save detailed predictions
        if self.completed_predictions:
            predictions_df = self.to_dataframe()
            predictions_df.to_csv(output_path / "ripeness_predictions.csv", index=False)

        # Generate human-readable report
        report_lines = [
            "Ripeness Classification Accuracy Report",
            "=" * 60,
            f"Total predictions: {metrics['total_predictions']}",
            f"Completed predictions: {metrics['completed_predictions']}",
            "",
            "Accuracy Metrics:",
            f" False positive rate (RIPE but adjourned): {metrics['false_positive_rate']:.1%}",
            f" False negative rate (UNRIPE but progressed): {metrics['false_negative_rate']:.1%}",
            f" UNKNOWN rate: {metrics['unknown_rate']:.1%}",
            f" RIPE precision (progressed | predicted RIPE): {metrics['ripe_precision']:.1%}",
            f" UNRIPE recall (predicted UNRIPE | adjourned): {metrics['unripe_recall']:.1%}",
            "",
            "Confusion Matrix:",
            f" RIPE -> Progressed: {matrix['RIPE']['progressed']}, Adjourned: {matrix['RIPE']['adjourned']}",
            f" UNRIPE -> Progressed: {matrix['UNRIPE']['progressed']}, Adjourned: {matrix['UNRIPE']['adjourned']}",
            f" UNKNOWN -> Progressed: {matrix['UNKNOWN']['progressed']}, Adjourned: {matrix['UNKNOWN']['adjourned']}",
            "",
            "Interpretation:",
        ]

        # Add interpretation hints keyed to the calibrator's heuristics.
        if metrics['false_positive_rate'] > 0.20:
            report_lines.append(" - HIGH false positive rate: Consider increasing MIN_SERVICE_HEARINGS")
        if metrics['false_negative_rate'] > 0.15:
            report_lines.append(" - HIGH false negative rate: Consider decreasing MIN_STAGE_DAYS")
        if metrics['unknown_rate'] < 0.05:
            report_lines.append(" - LOW UNKNOWN rate: System may be overconfident, add uncertainty")
        if metrics['ripe_precision'] > 0.85:
            report_lines.append(" - GOOD RIPE precision: Most RIPE predictions are correct")
        if metrics['unripe_recall'] < 0.60:
            report_lines.append(" - LOW UNRIPE recall: Missing many bottlenecks, refine detection")

        report_text = "\n".join(report_lines)
        (output_path / "ripeness_report.txt").write_text(report_text)

        print(f"Ripeness accuracy report saved to {output_path}")
scheduler/optimization/__init__.py DELETED
File without changes
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """EDA pipeline modules."""
src/run_eda.py CHANGED
@@ -1,23 +1,11 @@
1
- """Entrypoint to run the full EDA + parameter pipeline.
 
2
 
3
- Order:
4
- 1. Load & clean (save Parquet + metadata)
5
- 2. Visual EDA (plots + CSV summaries)
6
- 3. Parameter extraction (JSON/CSV priors + features)
7
  """
8
 
9
- from src.eda_exploration import run_exploration
10
- from src.eda_load_clean import run_load_and_clean
11
- from src.eda_parameters import run_parameter_export
12
 
13
  if __name__ == "__main__":
14
- print("Step 1/3: Load and clean")
15
- run_load_and_clean()
16
-
17
- print("\nStep 2/3: Exploratory analysis and plots")
18
- run_exploration()
19
-
20
- print("\nStep 3/3: Parameter extraction for simulation/scheduler")
21
- run_parameter_export()
22
-
23
- print("\nAll steps complete.")
 
1
#!/usr/bin/env python
"""Main entry point for Court Scheduling System.

This file provides the primary entry point for the project.
It invokes the CLI which provides all scheduling system operations.
"""

# NOTE(review): this commit consolidates court_scheduler/ into cli/ and points
# the packaged entry point at cli.main:app, so the old `court_scheduler.cli`
# module no longer exists — import the CLI app from its new location instead.
from cli.main import app


if __name__ == "__main__":
    app()
 
 
 
 
 
 
 
 
 
test_enhancements.py → tests/test_enhancements.py RENAMED
File without changes
tests/test_gap_fixes.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test script to validate both gap fixes.
2
+
3
+ Tests:
4
+ 1. Gap 1: RL training uses EDA-derived parameters
5
+ 2. Gap 2: Ripeness feedback loop works
6
+ """
7
+
8
+ from datetime import date, datetime
9
+ from pathlib import Path
10
+
11
+ from scheduler.data.case_generator import CaseGenerator
12
+ from scheduler.data.param_loader import ParameterLoader
13
+ from scheduler.core.ripeness import RipenessClassifier, RipenessStatus
14
+ from scheduler.monitoring.ripeness_metrics import RipenessMetrics
15
+ from scheduler.monitoring.ripeness_calibrator import RipenessCalibrator
16
+ from rl.training import RLTrainingEnvironment, train_agent
17
+ from rl.simple_agent import TabularQAgent
18
+ from rl.config import RLTrainingConfig
19
+
20
+
21
def test_gap1_eda_alignment():
    """Test that RL training uses EDA-derived parameters."""
    print("\n" + "=" * 70)
    print("GAP 1: Testing EDA Alignment in RL Training")
    print("=" * 70)

    # Build a small deterministic caseload for the environment.
    case_gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 31), seed=42)
    cases = case_gen.generate(100, stage_mix_auto=True)

    env = RLTrainingEnvironment(
        cases=cases,
        start_date=date(2024, 1, 1),
        horizon_days=30,
    )

    # The environment must expose an EDA-backed parameter loader.
    assert hasattr(env, 'param_loader'), "Environment should have param_loader"
    assert isinstance(env.param_loader, ParameterLoader), "param_loader should be ParameterLoader instance"

    print("✓ ParameterLoader successfully integrated into RLTrainingEnvironment")

    # Probe one case fixed at a known stage/type.
    probe = cases[0]
    probe.current_stage = "ADMISSION"
    probe.case_type = "RSA"

    p_adj_eda = env.param_loader.get_adjournment_prob("ADMISSION", "RSA")
    print(f"✓ EDA adjournment probability for ADMISSION/RSA: {p_adj_eda:.2%}")

    # Sample hearing outcomes and compare the empirical adjournment rate.
    samples = [env._simulate_hearing_outcome(probe) for _ in range(100)]
    adjourn_rate = sum(outcome == "ADJOURNED" for outcome in samples) / len(samples)
    print(f"✓ Simulated adjournment rate: {adjourn_rate:.2%}")
    print(f" Difference from EDA: {abs(adjourn_rate - p_adj_eda):.2%}")

    # Should be within 15% of EDA value (stochastic sampling)
    assert abs(adjourn_rate - p_adj_eda) < 0.15, f"Adjournment rate {adjourn_rate:.2%} too far from EDA {p_adj_eda:.2%}"

    print("\n✅ GAP 1 FIXED: RL training now uses EDA-derived parameters\n")
71
+
72
+
73
def test_gap2_ripeness_feedback():
    """Test that ripeness feedback loop works."""
    print("\n" + "=" * 70)
    print("GAP 2: Testing Ripeness Feedback Loop")
    print("=" * 70)

    metrics = RipenessMetrics()

    # Repeating 4-entry pattern: 50% false positives among RIPE and 50%
    # false negatives among UNRIPE. 50+ samples are needed by the calibrator.
    pattern = [
        (RipenessStatus.RIPE, False),            # Correct RIPE
        (RipenessStatus.RIPE, True),             # False positive
        (RipenessStatus.UNRIPE_SUMMONS, True),   # Correct UNRIPE
        (RipenessStatus.UNRIPE_SUMMONS, False),  # False negative
    ]
    test_cases = [(f"case{i}", *pattern[i % 4]) for i in range(50)]

    prediction_date = datetime(2024, 1, 1)
    outcome_date = datetime(2024, 1, 2)

    for case_id, predicted_status, was_adjourned in test_cases:
        metrics.record_prediction(case_id, predicted_status, prediction_date)
        metrics.record_outcome(
            case_id,
            "ADJOURNED" if was_adjourned else "ARGUMENTS",
            was_adjourned,
            outcome_date,
        )

    print(f"✓ Recorded {len(test_cases)} predictions and outcomes")

    accuracy = metrics.get_accuracy_metrics()
    print(f"\n Accuracy Metrics:")
    print(f" False positive rate: {accuracy['false_positive_rate']:.1%}")
    print(f" False negative rate: {accuracy['false_negative_rate']:.1%}")
    print(f" RIPE precision: {accuracy['ripe_precision']:.1%}")
    print(f" UNRIPE recall: {accuracy['unripe_recall']:.1%}")

    # Expected: 2/4 false positives (50%), 1/2 false negatives (50%)
    assert accuracy['false_positive_rate'] > 0.4, "Should detect false positives"
    assert accuracy['false_negative_rate'] > 0.4, "Should detect false negatives"

    print("\n✓ RipenessMetrics successfully tracks classification accuracy")

    adjustments = RipenessCalibrator.analyze_metrics(metrics)

    print(f"\n✓ RipenessCalibrator generated {len(adjustments)} adjustment suggestions:")
    for adj in adjustments:
        print(f" - {adj.threshold_name}: {adj.current_value} → {adj.suggested_value}")
        print(f" Reason: {adj.reason[:80]}...")

    assert len(adjustments) > 0, "Should suggest at least one adjustment"

    # Exercise the dynamic-threshold API, then restore the originals.
    original_thresholds = RipenessClassifier.get_current_thresholds()
    print(f"\n✓ Current thresholds: {original_thresholds}")

    RipenessClassifier.set_thresholds({"MIN_SERVICE_HEARINGS": 2})

    new_thresholds = RipenessClassifier.get_current_thresholds()
    assert new_thresholds["MIN_SERVICE_HEARINGS"] == 2, "Threshold should be updated"
    print(f"✓ Thresholds successfully updated: {new_thresholds}")

    # Restore original
    RipenessClassifier.set_thresholds(original_thresholds)

    print("\n✅ GAP 2 FIXED: Ripeness feedback loop fully operational\n")
146
+
147
+
148
def test_end_to_end():
    """Quick end-to-end test with small training run."""
    print("\n" + "=" * 70)
    print("END-TO-END: Testing Both Gaps Together")
    print("=" * 70)

    # Deliberately tiny run: exercises the full loop while staying fast.
    config = RLTrainingConfig(
        episodes=2,
        episode_length_days=10,
        cases_per_episode=50,
        training_seed=42,
    )
    agent = TabularQAgent(learning_rate=0.15, epsilon=0.4, discount=0.95)

    print("Running mini training (2 episodes, 50 cases, 10 days)...")
    stats = train_agent(agent, rl_config=config, verbose=False)

    assert len(stats["episodes"]) == 2, "Should complete 2 episodes"
    assert stats["episodes"][-1] == 1, "Last episode should be episode 1"

    print(f"✓ Training completed: {len(stats['episodes'])} episodes")
    print(f" Final disposal rate: {stats['disposal_rates'][-1]:.1%}")
    print(f" States explored: {stats['states_explored'][-1]}")

    print("\n✅ END-TO-END: Both gaps working together successfully\n")
176
+
177
+
178
if __name__ == "__main__":
    print("\n" + "=" * 70)
    print("TESTING GAP FIXES")
    print("=" * 70)

    try:
        # Run each gap check in order; any assertion failure aborts the run.
        for check in (test_gap1_eda_alignment, test_gap2_ripeness_feedback, test_end_to_end):
            check()

        print("\n" + "=" * 70)
        print("ALL TESTS PASSED")
        print("=" * 70)
        print("\nSummary:")
        print(" ✅ Gap 1: RL training aligned with EDA parameters")
        print(" ✅ Gap 2: Ripeness feedback loop operational")
        print(" ✅ End-to-end: Both gaps working together")
        print("\nBoth confirmed gaps are now FIXED!")
        print("=" * 70 + "\n")

    except Exception as e:
        # Surface the failure for the console, then re-raise for CI.
        print(f"\n❌ TEST FAILED: {e}")
        raise
train_rl_agent.py DELETED
@@ -1,238 +0,0 @@
1
- """Configuration-driven RL agent training and evaluation.
2
-
3
- Modular training pipeline for reinforcement learning in court scheduling.
4
- """
5
-
6
- import argparse
7
- import json
8
- import numpy as np
9
- from pathlib import Path
10
- from datetime import date
11
- from dataclasses import dataclass
12
- from typing import Dict, Any
13
-
14
- from rl.simple_agent import TabularQAgent
15
- from rl.training import train_agent, evaluate_agent
16
- from scheduler.data.case_generator import CaseGenerator
17
-
18
-
19
@dataclass
class TrainingConfig:
    """Training configuration parameters."""
    episodes: int = 50
    cases_per_episode: int = 500
    episode_length: int = 30
    learning_rate: float = 0.1
    initial_epsilon: float = 0.3
    discount: float = 0.95
    model_name: str = "trained_rl_agent.pkl"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'TrainingConfig':
        """Create config from dictionary."""
        # Drop keys that are not declared fields so extra entries in a
        # config file do not raise TypeError in the generated __init__.
        known_fields = cls.__annotations__
        accepted = {key: value for key, value in config_dict.items() if key in known_fields}
        return cls(**accepted)

    @classmethod
    def from_file(cls, config_path: Path) -> 'TrainingConfig':
        """Load config from JSON file."""
        with open(config_path) as f:
            raw = json.load(f)
        return cls.from_dict(raw)
40
-
41
-
42
def run_training_experiment(config: TrainingConfig = None):
    """Run configurable RL training experiment.

    Trains a tabular Q-learning agent, saves it under models/, evaluates it
    on freshly generated test cases, then prints a diagnostic report:
    learning progression, state-space coverage, baseline comparison, and
    tuning recommendations.

    Args:
        config: Training configuration. If None, uses defaults.

    Returns:
        Tuple of (agent, training_stats, evaluation_results).
    """
    if config is None:
        config = TrainingConfig()

    print("=" * 70)
    print("RL AGENT TRAINING EXPERIMENT")
    print("=" * 70)

    print(f"Training Parameters:")
    print(f"  Episodes: {config.episodes}")
    print(f"  Cases per episode: {config.cases_per_episode}")
    print(f"  Episode length: {config.episode_length} days")
    print(f"  Learning rate: {config.learning_rate}")
    print(f"  Initial exploration: {config.initial_epsilon}")

    # Initialize agent
    agent = TabularQAgent(
        learning_rate=config.learning_rate,
        epsilon=config.initial_epsilon,
        discount=config.discount
    )

    print(f"\nInitial agent state: {agent.get_stats()}")

    # Training phase
    print("\n" + "=" * 50)
    print("TRAINING PHASE")
    print("=" * 50)

    training_stats = train_agent(
        agent=agent,
        episodes=config.episodes,
        cases_per_episode=config.cases_per_episode,
        episode_length=config.episode_length,
        verbose=True
    )

    # Save trained agent (models/ is created on first run)
    model_path = Path("models")
    model_path.mkdir(exist_ok=True)
    agent_file = model_path / config.model_name
    agent.save(agent_file)
    print(f"\nTrained agent saved to: {agent_file}")

    # Generate test cases for evaluation
    print("\n" + "=" * 50)
    print("EVALUATION PHASE")
    print("=" * 50)

    # Fixed date window and seed keep the evaluation set reproducible and
    # distinct from training data.
    test_start = date(2024, 7, 1)
    test_end = date(2024, 8, 1)
    test_generator = CaseGenerator(start=test_start, end=test_end, seed=999)
    test_cases = test_generator.generate(1000, stage_mix_auto=True)

    print(f"Generated {len(test_cases)} test cases")

    # Evaluate trained agent
    evaluation_results = evaluate_agent(
        agent=agent,
        test_cases=test_cases,
        episodes=5,
        episode_length=60
    )

    # Print final analysis
    print("\n" + "=" * 50)
    print("TRAINING ANALYSIS")
    print("=" * 50)

    final_stats = agent.get_stats()
    print(f"Final agent statistics:")
    print(f"  States explored: {final_stats['states_visited']:,}")
    print(f"  Q-table size: {final_stats['q_table_size']:,}")
    print(f"  Total Q-updates: {final_stats['total_updates']:,}")
    print(f"  Final epsilon: {final_stats['epsilon']:.3f}")

    # Training progression analysis: compare mean disposal rate of the
    # first 10 episodes against the last 10 (needs >= 10 data points).
    if len(training_stats["disposal_rates"]) >= 10:
        early_performance = np.mean(training_stats["disposal_rates"][:10])
        late_performance = np.mean(training_stats["disposal_rates"][-10:])
        improvement = late_performance - early_performance

        print(f"\nLearning progression:")
        print(f"  Early episodes (1-10): {early_performance:.1%} disposal rate")
        print(f"  Late episodes (-10 to end): {late_performance:.1%} disposal rate")
        print(f"  Improvement: {improvement:.1%}")

        if improvement > 0.01:  # 1% improvement threshold
            print("  STATUS: Agent showed learning progress")
        else:
            print("  STATUS: Limited learning detected")

    # State space coverage analysis
    theoretical_states = 11 * 10 * 10 * 2 * 2 * 10  # 6D discretized state space
    coverage = final_stats['states_visited'] / theoretical_states
    print(f"\nState space analysis:")
    print(f"  Theoretical max states: {theoretical_states:,}")
    print(f"  States actually visited: {final_stats['states_visited']:,}")
    print(f"  Coverage: {coverage:.1%}")

    if coverage < 0.01:
        print("  WARNING: Very low state space exploration")
    elif coverage < 0.1:
        print("  NOTE: Limited state space exploration (expected)")
    else:
        print("  GOOD: Reasonable state space exploration")

    print("\n" + "=" * 50)
    print("PERFORMANCE SUMMARY")
    print("=" * 50)

    print(f"Trained RL Agent Performance:")
    print(f"  Mean disposal rate: {evaluation_results['mean_disposal_rate']:.1%}")
    print(f"  Standard deviation: {evaluation_results['std_disposal_rate']:.1%}")
    print(f"  Mean utilization: {evaluation_results['mean_utilization']:.1%}")
    print(f"  Avg hearings to disposal: {evaluation_results['mean_hearings_to_disposal']:.1f}")

    # Compare with baseline from previous runs (known values)
    # NOTE(review): hard-coded from an earlier experiment — re-measure if
    # the readiness policy or simulation parameters change.
    baseline_disposal = 0.107  # 10.7% from readiness policy
    rl_disposal = evaluation_results['mean_disposal_rate']

    print(f"\nComparison with Baseline:")
    print(f"  Baseline (Readiness): {baseline_disposal:.1%}")
    print(f"  RL Agent: {rl_disposal:.1%}")
    print(f"  Difference: {(rl_disposal - baseline_disposal):.1%}")

    if rl_disposal > baseline_disposal + 0.01:  # 1% improvement threshold
        print("  RESULT: RL agent outperforms baseline")
    elif rl_disposal > baseline_disposal - 0.01:
        print("  RESULT: RL agent performs comparably to baseline")
    else:
        print("  RESULT: RL agent underperforms baseline")

    # Recommendations
    print("\n" + "=" * 50)
    print("RECOMMENDATIONS")
    print("=" * 50)

    if coverage < 0.01:
        print("1. Increase training episodes for better state exploration")
        print("2. Consider state space dimensionality reduction")

    if final_stats['total_updates'] < 10000:
        print("3. Extend training duration for more Q-value updates")

    if evaluation_results['std_disposal_rate'] > 0.05:
        print("4. High variance detected - consider ensemble methods")

    if rl_disposal <= baseline_disposal:
        print("5. Reward function may need tuning")
        print("6. Consider different exploration strategies")
        print("7. Baseline policy is already quite effective")

    print("\nExperiment complete.")
    return agent, training_stats, evaluation_results
202
-
203
-
204
def main():
    """CLI interface for RL training.

    Parses command-line arguments, optionally loads a JSON configuration
    file, applies CLI overrides, and runs the training experiment.

    Returns:
        The (agent, training_stats, evaluation_results) tuple produced by
        run_training_experiment().
    """
    parser = argparse.ArgumentParser(description="Train RL agent for court scheduling")
    parser.add_argument("--config", type=Path, help="Training configuration file (JSON)")
    parser.add_argument("--episodes", type=int, help="Number of training episodes")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--epsilon", type=float, help="Initial exploration rate")
    parser.add_argument("--model-name", help="Output model filename")

    args = parser.parse_args()

    # Load config from file when supplied; a missing file silently falls
    # back to defaults (best-effort behavior kept from the original).
    if args.config and args.config.exists():
        config = TrainingConfig.from_file(args.config)
        print(f"Loaded configuration from {args.config}")
    else:
        config = TrainingConfig()
        print("Using default configuration")

    # Override config with CLI args. Compare against None (argparse's
    # "not supplied" sentinel) rather than truthiness, so legitimate zero
    # values such as `--epsilon 0.0` are not silently ignored.
    if args.episodes is not None:
        config.episodes = args.episodes
    if args.learning_rate is not None:
        config.learning_rate = args.learning_rate
    if args.epsilon is not None:
        config.initial_epsilon = args.epsilon
    if args.model_name is not None:
        config.model_name = args.model_name

    # Run training
    return run_training_experiment(config)
235
-
236
-
237
# Script entry point: run the full training + evaluation experiment.
if __name__ == "__main__":
    main()